@@ -97,10 +97,36 @@ function unflatten_long_ops(op, N=4)
97
97
Rewriters. Fixpoint (Rewriters. Postwalk (Rewriters. Chain ([rule1, rule2])))(op)
98
98
end
99
99
100
+ struct Let
101
+ eqs:: Vector
102
+ body
103
+ end
104
+
105
+ function observed_let (eqs)
106
+ process -> ex -> begin
107
+ lhss = map (eq-> process (eq. lhs), eqs)
108
+ rhss = map (eq-> process (eq. rhs), eqs)
109
+ letexpr = Expr (:let )
110
+ assignments = quote end
111
+ for (l, r) in zip (lhss, rhss)
112
+ push! (assignments. args, :($ l = $ r))
113
+ end
114
+ push! (letexpr. args, assignments)
115
+ push! (letexpr. args, ex)
116
+ letexpr
117
+ end
118
+ end
119
+
120
+ function _build_function (target:: JuliaTarget , op:: Let , args... ; conv= toexpr, kw... )
121
+ _build_function (target, op. body, args... ;
122
+ inner_let = observed_let (op. eqs), kw... )
123
+ end
124
+
100
125
# Scalar output
101
126
function _build_function (target:: JuliaTarget , op, args... ;
102
127
conv = toexpr, expression = Val{true },
103
128
checkbounds = false ,
129
+ inner_let = nothing ,
104
130
linenumbers = true , headerfun= addheader)
105
131
106
132
argnames = [gensym (:MTKArg ) for i in 1 : length (args)]
@@ -109,12 +135,18 @@ function _build_function(target::JuliaTarget, op, args...;
109
135
process = unflatten_long_ops∘ (x-> substitute (x, symsdict, fold= false ))
110
136
ls = reduce (vcat,conv .(first .(arg_pairs)))
111
137
rs = reduce (vcat,last .(arg_pairs))
112
- var_eqs = Expr (:(= ), ModelingToolkit. build_expr (:tuple , ls), ModelingToolkit. build_expr (:tuple , conv .(process .(rs))))
138
+ var_eqs = Expr (:(= ), build_expr (:tuple , ls), build_expr (:tuple , conv .(process .(rs))))
139
+
140
+ if inner_let != = nothing
141
+ inner_let_expr = inner_let (conv ∘ process)
142
+ else
143
+ inner_let_expr = identity
144
+ end
113
145
114
146
fname = gensym (:ModelingToolkitFunction )
115
147
op = process (op)
116
148
out_expr = conv (substitute (op, symsdict, fold= false ))
117
- let_expr = Expr (:let , var_eqs, Expr (:block , out_expr))
149
+ let_expr = Expr (:let , var_eqs, Expr (:block , inner_let_expr ( out_expr) ))
118
150
bounds_block = checkbounds ? let_expr : :(@inbounds begin $ let_expr end )
119
151
120
152
fargs = Expr (:tuple ,argnames... )
@@ -218,6 +250,7 @@ Special Keyword Argumnets:
218
250
"""
219
251
function _build_function (target:: JuliaTarget , rhss:: AbstractArray , args... ;
220
252
conv = toexpr, expression = Val{true },
253
+ inner_let = nothing ,
221
254
checkbounds = false ,
222
255
linenumbers = false , multithread= nothing ,
223
256
headerfun = addheader, outputidxs= nothing ,
@@ -235,6 +268,12 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
235
268
arg_pairs = map ((x,y)-> vars_to_pairs (x,y, symsdict), argnames, args)
236
269
process = unflatten_long_ops∘ (x-> substitute (x, symsdict, fold= false ))
237
270
271
+ if inner_let != = nothing
272
+ inner_let_expr = inner_let (conv ∘ process)
273
+ else
274
+ inner_let_expr = identity
275
+ end
276
+
238
277
ls = reduce (vcat,conv .(first .(arg_pairs)))
239
278
rs = reduce (vcat,last .(arg_pairs))
240
279
var_eqs = Expr (:(= ), ModelingToolkit. build_expr (:tuple , ls), ModelingToolkit. build_expr (:tuple , conv .(process .(rs))))
@@ -440,9 +479,10 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
440
479
end
441
480
end : arr_sys_expr
442
481
443
- let_expr = Expr (:let , var_eqs, tuple_sys_expr)
444
- arr_let_expr = Expr (:let , var_eqs, arr_sys_expr)
445
- bounds_block = checkbounds ? let_expr : :(@inbounds begin $ let_expr end )
482
+ arr_let_expr = Expr (:let , var_eqs, inner_let_expr (arr_sys_expr))
483
+ idx = findfirst (x-> Meta. isexpr (x, :let ), ip_let_expr. args)
484
+ ip_let_expr. args[idx]. args[2 ] = inner_let_expr (ip_let_expr. args[idx]. args[2 ])
485
+
446
486
oop_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $ arr_let_expr end )
447
487
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ ip_let_expr end )
448
488
0 commit comments