Skip to content

Commit 60d5df0

Browse files
Merge pull request #626 from AayushSabharwal/as/symbolic-array-codegen
feat: support codegen for expressions involving arrays of symbolics
2 parents da9267e + a6929f4 commit 60d5df0

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/code.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import ..SymbolicUtils
1010
import ..SymbolicUtils.Rewriters
1111
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
1212
symtype, sorted_arguments, metadata, isterm, term, maketerm
13+
import SymbolicIndexingInterface: symbolic_type, NotSymbolic
1314

1415
##== state management ==##
1516

@@ -169,13 +170,24 @@ function substitute_name(O, st)
169170
end
170171
end
171172

173+
function _is_array_of_symbolics(O)
174+
# O is an array, not a symbolic array, and either has a non-symbolic eltype or contains elements that are
175+
# symbolic or arrays of symbolics
176+
return O isa AbstractArray && symbolic_type(O) == NotSymbolic() &&
177+
(symbolic_type(eltype(O)) != NotSymbolic() ||
178+
any(x -> symbolic_type(x) != NotSymbolic() || _is_array_of_symbolics(x), O))
179+
end
180+
172181
function toexpr(O, st)
173182
if issym(O)
174183
O = substitute_name(O, st)
175184
return issym(O) ? nameof(O) : toexpr(O, st)
176185
end
177186
O = substitute_name(O, st)
178187

188+
if _is_array_of_symbolics(O)
189+
return toexpr(MakeArray(O, typeof(O)), st)
190+
end
179191
!iscall(O) && return O
180192
op = operation(O)
181193
expr′ = function_to_expr(op, O, st)

test/code.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,4 +219,12 @@ nanmath_st.rewrites[:nanmath] = true
219219
@test s1 == s2
220220
end
221221
end
222+
223+
let
224+
@syms a b
225+
226+
t = term(sum, [a, b, a + b, 3a + 2b, sqrt(b)]; type = Number)
227+
f = eval(toexpr(Func([a, b], [], t)))
228+
@test f(1.0, 2.0) 13.0 + sqrt(2)
229+
end
222230
end

0 commit comments

Comments
 (0)