Skip to content

Commit 673146c

Browse files
committed
feat: n-arity strings
1 parent cda0896 commit 673146c

File tree

1 file changed

+37
-34
lines changed

1 file changed

+37
-34
lines changed

src/Strings.jl

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,43 @@ module StringsModule
22

33
using ..UtilsModule: deprecate_varmap
44
using ..OperatorEnumModule: AbstractOperatorEnum
5-
using ..NodeModule: AbstractExpressionNode, tree_mapreduce
5+
using ..NodeModule: AbstractExpressionNode, tree_mapreduce, max_degree
66

77
function dispatch_op_name(
88
::Val{deg}, ::Nothing, idx, pretty::Bool
99
)::Vector{Char} where {deg}
10-
return vcat(
11-
collect(deg == 1 ? "unary_operator[" : "binary_operator["),
12-
collect(string(idx)),
13-
[']'],
14-
)
10+
return vcat(collect(
11+
if deg == 1
12+
"unary_operator["
13+
elseif deg == 2
14+
"binary_operator["
15+
else
16+
"operator_deg$deg["
17+
end,
18+
), collect(string(idx)), [']'])
1519
end
1620
function dispatch_op_name(
1721
::Val{deg}, operators::AbstractOperatorEnum, idx, pretty::Bool
1822
) where {deg}
19-
op = if deg == 1
20-
operators.unaops[idx]
21-
else
22-
operators.binops[idx]
23-
end
23+
op = operators[deg][idx]
2424
return collect((pretty ? get_pretty_op_name(op) : get_op_name(op))::String)
2525
end
2626

27+
struct OpNameDispatcher{D,O<:AbstractOperatorEnum} <: Function
28+
operators::O
29+
pretty::Bool
30+
end
31+
@generated function (f::OpNameDispatcher{D,O})(branch) where {D,O}
32+
return quote
33+
degree = branch.degree
34+
Base.Cartesian.@nif(
35+
$D,
36+
d -> d == degree,
37+
d -> dispatch_op_name(Val(d), f.operators, branch.op, f.pretty),
38+
)::Vector{Char}
39+
end
40+
end
41+
2742
const OP_NAME_CACHE = (; x=Dict{UInt64,String}(), lock=Threads.SpinLock())
2843

2944
function get_op_name(op::F) where {F}
@@ -89,36 +104,30 @@ function string_variable(feature, variable_names)
89104
end
90105

91106
# Vector of chars is faster than strings, so we use that.
92-
function combine_op_with_inputs(op, l, r)::Vector{Char}
93-
if first(op) in ('+', '-', '*', '/', '^', '.', '>', '<', '=') || op == "!="
107+
function combine_op_with_inputs(op, args::Vararg{Any,D})::Vector{Char} where {D}
108+
if D == 2 && (first(op) in ('+', '-', '*', '/', '^', '.', '>', '<', '=') || op == "!=")
94109
# "(l op r)"
95110
out = ['(']
96-
append!(out, l)
111+
append!(out, args[1])
97112
push!(out, ' ')
98113
append!(out, op)
99114
push!(out, ' ')
100-
append!(out, r)
115+
append!(out, args[2])
101116
push!(out, ')')
102117
else
103118
# "op(l, r)"
104119
out = copy(op)
105120
push!(out, '(')
106-
append!(out, strip_brackets(l))
107-
push!(out, ',')
108-
push!(out, ' ')
109-
append!(out, strip_brackets(r))
121+
for i in 1:(D - 1)
122+
append!(out, strip_brackets(args[i]))
123+
push!(out, ',')
124+
push!(out, ' ')
125+
end
126+
append!(out, strip_brackets(args[D]))
110127
push!(out, ')')
111128
return out
112129
end
113130
end
114-
function combine_op_with_inputs(op, l)
115-
# "op(l)"
116-
out = copy(op)
117-
push!(out, '(')
118-
append!(out, strip_brackets(l))
119-
push!(out, ')')
120-
return out
121-
end
122131

123132
"""
124133
string_tree(
@@ -169,13 +178,7 @@ function string_tree(
169178
collect(f_variable(leaf.feature, variable_names))::Vector{Char}
170179
end
171180
end,
172-
let operators = operators
173-
(branch,) -> if branch.degree == 1
174-
dispatch_op_name(Val(1), operators, branch.op, pretty)::Vector{Char}
175-
else
176-
dispatch_op_name(Val(2), operators, branch.op, pretty)::Vector{Char}
177-
end
178-
end,
181+
OpNameDispatcher{max_degree(tree),typeof(operators)}(operators, pretty),
179182
combine_op_with_inputs,
180183
tree,
181184
Vector{Char};

0 commit comments

Comments
 (0)