Skip to content

Commit b3e21bf

Browse files
authored
Merge pull request #801 from SciML/s/unflatten
use unflatten_long_ops in build_function again
2 parents ae4f53c + 7d5839d commit b3e21bf

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

src/build_function.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,19 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
5252
_build_function(target,args...;kwargs...)
5353
end
5454

55-
function unflatten_long_ops(op, N=4)
56-
rule1 = @rule((+)((~~x)) => length(~~x) > N ?
57-
+(+((~~x)[1:N]...) + (+)((~~x)[N+1:end]...)) : nothing)
58-
rule2 = @rule((*)((~~x)) => length(~~x) > N ?
59-
*(*((~~x)[1:N]...) * (*)((~~x)[N+1:end]...)) : nothing)
55+
function unflatten_args(f, args, N=4)
56+
length(args) < N && return Term{Real}(f, args)
57+
unflatten_args(f, [Term{Real}(f, group)
58+
for group in Iterators.partition(args, N)], N)
59+
end
6060

61+
function unflatten_long_ops(op, N=4)
6162
op = value(op)
62-
Rewriters.Fixpoint(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2])))(op)
63+
!istree(op) && return Num(op)
64+
rule1 = @rule((+)(~~x) => length(~~x) > N ? unflatten_args(+, ~~x, 4) : nothing)
65+
rule2 = @rule((*)(~~x) => length(~~x) > N ? unflatten_args(*, ~~x, 4) : nothing)
66+
67+
Num(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2]))(op))
6368
end
6469

6570

@@ -76,7 +81,7 @@ function _build_function(target::JuliaTarget, op, args...;
7681
linenumbers = true)
7782

7883
dargs = map(destructure_arg, [args...])
79-
expr = toexpr(Func(dargs, [], op))
84+
expr = toexpr(Func(dargs, [], unflatten_long_ops(op)))
8085

8186
if expression == Val{true}
8287
expr
@@ -247,7 +252,7 @@ function _make_array(rhss::AbstractArray, similarto)
247252
end
248253
end
249254

250-
_make_array(x, similarto) = x
255+
_make_array(x, similarto) = unflatten_long_ops(x)
251256

252257
## In-place version
253258

@@ -298,7 +303,7 @@ function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros
298303
end)
299304
end
300305

301-
_set_array(out, outputidxs, rhs, checkbounds, skipzeros) = rhs
306+
_set_array(out, outputidxs, rhs, checkbounds, skipzeros) = unflatten_long_ops(rhs)
302307

303308

304309
function vars_to_pairs(name,vs::Union{Tuple, AbstractArray}, symsdict=Dict())

0 commit comments

Comments
 (0)