@@ -52,14 +52,19 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
52
52
_build_function (target,args... ;kwargs... )
53
53
end
54
54
55
- function unflatten_long_ops (op , N= 4 )
56
- rule1 = @rule (( + )(( ~~ x)) => length (~~ x) > N ?
57
- + ( + (( ~~ x)[ 1 : N] . .. ) + ( + )(( ~~ x)[N + 1 : end ] . .. )) : nothing )
58
- rule2 = @rule (( * )(( ~~ x)) => length ( ~~ x) > N ?
59
- * ( * (( ~~ x)[ 1 : N] . .. ) * ( * )(( ~~ x)[N + 1 : end ] . .. )) : nothing )
55
+ function unflatten_args (f, args , N= 4 )
56
+ length (args) < N && return Term {Real} (f, args)
57
+ unflatten_args (f, [ Term {Real} (f, group )
58
+ for group in Iterators . partition (args, N)], N)
59
+ end
60
60
61
+ function unflatten_long_ops (op, N= 4 )
61
62
op = value (op)
62
- Rewriters. Fixpoint (Rewriters. Postwalk (Rewriters. Chain ([rule1, rule2])))(op)
63
+ ! istree (op) && return Num (op)
64
+ rule1 = @rule ((+ )(~~ x) => length (~~ x) > N ? unflatten_args (+ , ~~ x, 4 ) : nothing )
65
+ rule2 = @rule ((* )(~~ x) => length (~~ x) > N ? unflatten_args (* , ~~ x, 4 ) : nothing )
66
+
67
+ Num (Rewriters. Postwalk (Rewriters. Chain ([rule1, rule2]))(op))
63
68
end
64
69
65
70
@@ -76,7 +81,7 @@ function _build_function(target::JuliaTarget, op, args...;
76
81
linenumbers = true )
77
82
78
83
dargs = map (destructure_arg, [args... ])
79
- expr = toexpr (Func (dargs, [], op ))
84
+ expr = toexpr (Func (dargs, [], unflatten_long_ops (op) ))
80
85
81
86
if expression == Val{true }
82
87
expr
@@ -247,7 +252,7 @@ function _make_array(rhss::AbstractArray, similarto)
247
252
end
248
253
end
249
254
250
- _make_array (x, similarto) = x
255
+ _make_array (x, similarto) = unflatten_long_ops (x)
251
256
252
257
# # In-place version
253
258
@@ -298,7 +303,7 @@ function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros
298
303
end )
299
304
end
300
305
301
- _set_array (out, outputidxs, rhs, checkbounds, skipzeros) = rhs
306
+ _set_array (out, outputidxs, rhs, checkbounds, skipzeros) = unflatten_long_ops ( rhs)
302
307
303
308
304
309
function vars_to_pairs (name,vs:: Union{Tuple, AbstractArray} , symsdict= Dict ())
0 commit comments