Skip to content

Commit faf3ee2

Browse files
authored
Merge pull request #194 from JuliaSymbolics/s/larray
Code module improvements
2 parents 40fe575 + a326bf5 commit faf3ee2

File tree

6 files changed

+105
-54
lines changed

6 files changed

+105
-54
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
page/* linguist-vendored

src/code.jl

Lines changed: 77 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,49 @@ function toexpr(O, st)
100100
return toexpr(Term{Any}(inv, [ex]), st)
101101
else
102102
return toexpr(Term{Any}(^, [Term{Any}(inv, [ex]), -args[2]]), st)
103-
end
104-
elseif op === (SymbolicUtils.ifelse)
105-
return :($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st)))
106-
elseif op isa Sym && O in st.symbolify
107-
return Symbol(string(O))
108-
end
103+
end elseif op === (SymbolicUtils.ifelse) return :($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st))) elseif op isa Sym && O in st.symbolify return Symbol(string(O)) end
109104
return Expr(:call, toexpr(op, st), map(x->toexpr(x, st), args)...)
110105
end
111106

107+
# Call elements of vector arguments by their name.
108+
@matchable struct DestructuredArgs
109+
elems
110+
inds
111+
name
112+
end
113+
114+
function DestructuredArgs(elems, name=gensym("arg"); inds=eachindex(elems))
115+
DestructuredArgs(elems, inds, name)
116+
end
117+
118+
"""
119+
DestructuredArgs(elems, [name=gensym("arg")])
120+
121+
`elems` is a vector of symbols or call expressions. When it appears as an argument in
122+
`Func`, it expects a vector of the same length and de-structures the vector into its named
123+
components. See example in `Func` for more information.
124+
125+
`name` is the name to be used for the argument in the generated function Expr.
126+
"""
127+
DestructuredArgs
128+
129+
toexpr(x::DestructuredArgs, st) = toexpr(x.name, st)
130+
get_symbolify(args::DestructuredArgs) = ()
131+
function get_symbolify(args::Union{AbstractArray, Tuple})
132+
cflatten(map(get_symbolify, args))
133+
end
134+
get_symbolify(x) = istree(x) ? (x,) : ()
135+
cflatten(x) = Iterators.flatten(x) |> collect
136+
137+
function get_assignments(d::DestructuredArgs, st)
138+
name = toexpr(d, st)
139+
map(d.inds, d.elems) do i, a
140+
a (i isa Symbol ? :($name.$i) : :($name[$i]))
141+
end
142+
end
143+
112144
@matchable struct Let
113-
pairs::Vector{Assignment} # an iterator of pairs, ordered
145+
pairs::Vector{Union{Assignment,DestructuredArgs}} # an iterator of pairs, ordered
114146
body
115147
end
116148

@@ -125,13 +157,32 @@ A Let block.
125157
Let
126158

127159
function toexpr(l::Let, st)
160+
if all(x->x isa Assignment && !(x.lhs isa DestructuredArgs), l.pairs)
161+
dargs = l.pairs
162+
else
163+
dargs = map(l.pairs) do x
164+
if x isa DestructuredArgs
165+
get_assignments(x, st)
166+
elseif x isa Assignment && x.lhs isa DestructuredArgs
167+
[x.lhs.name x.rhs, get_assignments(x.lhs, st)...]
168+
else
169+
(x,)
170+
end
171+
end |> cflatten
172+
# expand and come back
173+
return toexpr(Let(dargs, l.body), st)
174+
end
175+
176+
funkyargs = get_symbolify(map(lhs, dargs))
177+
union!(st.symbolify, funkyargs)
178+
128179
Expr(:let,
129-
Expr(:block, map(p->toexpr(p, st), l.pairs)...),
180+
Expr(:block, map(p->toexpr(p, st), dargs)...),
130181
toexpr(l.body, st))
131182
end
132183

133184
@matchable struct Func
134-
args
185+
args::Vector
135186
kwargs
136187
body
137188
end
@@ -188,43 +239,12 @@ Func
188239

189240
toexpr_kw(f, st) = Expr(:kw, toexpr(f, st).args...)
190241

191-
# Call elements of vector arguments by their name.
192-
@matchable struct DestructuredArgs
193-
elems
194-
name
195-
end
196-
197-
DestructuredArgs(elems) = DestructuredArgs(elems, gensym("arg"))
198-
199-
"""
200-
DestructuredArgs(elems, [name=gensym("arg")])
201-
202-
`elems` is a vector of symbols or call expressions. When it appears as an argument in
203-
`Func`, it expects a vector of the same length and de-structures the vector into its named
204-
components. See example in `Func` for more information.
205-
206-
`name` is the name to be used for the argument in the generated function Expr.
207-
"""
208-
DestructuredArgs
209-
210-
toexpr(x::DestructuredArgs, st) = x.name
211-
get_symbolify(args::DestructuredArgs) = get_symbolify(args.elems)
212-
function get_symbolify(args::Union{AbstractArray, Tuple})
213-
cflatten(map(get_symbolify, args))
214-
end
215-
get_symbolify(x) = istree(x) ? (x,) : ()
216-
cflatten(x) = Iterators.flatten(x) |> collect
217-
218-
function get_assignments(d::DestructuredArgs, st)
219-
[a Expr(:ref, toexpr(d, st), i) for (i, a) in enumerate(d.elems)]
220-
end
221-
222242
function toexpr(f::Func, st)
223243
funkyargs = get_symbolify(vcat(f.args, map(lhs, f.kwargs)))
224-
dargs = filter(x->x isa DestructuredArgs, f.args)
225244
union!(st.symbolify, funkyargs)
245+
dargs = filter(x->x isa DestructuredArgs, f.args)
226246
if !isempty(dargs)
227-
body = Let(cflatten(map(x->get_assignments(x, st), dargs)), f.body)
247+
body = Let(dargs, f.body)
228248
else
229249
body = f.body
230250
end
@@ -324,7 +344,7 @@ function toexpr(a::MakeArray, st)
324344
$create_array($T,
325345
$elT,
326346
Val{$(size(a.elems))}(),
327-
$(toexpr.(a.elems, (st,))...),)
347+
$(map(x->toexpr(x, st), a.elems)...),)
328348
end
329349
end
330350

@@ -367,16 +387,23 @@ end
367387
end
368388

369389
## LabelledArrays
370-
@inline function create_array(A::Type{<:SLArray}, ::Nothing, d::Val{dims}, elems...) where {dims}
371-
a = create_array(SArray, nothing, d, elems...)
372-
similar_type(A, eltype(a), Size(dims))(a)
373-
end
374-
375390
@inline function create_array(A::Type{<:SLArray}, T, d::Val{dims}, elems...) where {dims}
376-
similar_type(A, T, Size(dims))(create_array(SArray, T, d, elems...))
391+
a = create_array(SArray, T, d, elems...)
392+
if nfields(dims) === ndims(A)
393+
similar_type(A, eltype(a), Size(dims))(a)
394+
else
395+
a
396+
end
377397
end
378398

379-
using SparseArrays
399+
@inline function create_array(A::Type{<:LArray}, T, d::Val{dims}, elems...) where {dims}
400+
data = create_array(Array, T, d, elems...)
401+
if nfields(dims) === ndims(A)
402+
LArray{eltype(data),nfields(dims),typeof(data),LabelledArrays.symnames(A)}(data)
403+
else
404+
data
405+
end
406+
end
380407

381408
## We use a separate type for Sparse Arrays to sidestep the need for
382409
## iszero to be defined on the expression type

src/methods.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,9 @@ end
139139
# An ifelse node, ifelse is a built-in unfortunately
140140
# So this uses IfElse.jl's ifelse that we imported
141141
function ifelse(_if::Symbolic{Bool}, _then, _else)
142-
Term{Union{symtype(_then), symtype(_else)}}(cond, Any[_if, _then, _else])
142+
Term{Union{symtype(_then), symtype(_else)}}(ifelse, Any[_if, _then, _else])
143143
end
144+
promote_symtype(::typeof(ifelse), _, ::Type{T}, ::Type{S}) where {T,S} = Union{T, S}
144145
Base.@deprecate cond(_if, _then, _else) ifelse(_if, _then, _else)
145146

146147
# Specially handle inv and literal pow

src/utils.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,15 @@ macro matchable(expr)
201201
if name isa Expr && name.head === :curly
202202
name = name.args[1]
203203
end
204-
fields = expr.args[3].args # Todo: get names
204+
fields = filter(x-> !(x isa LineNumberNode), expr.args[3].args)
205+
get_name(s::Symbol) = s
206+
get_name(e::Expr) = (@assert(e.head == :(::)); e.args[1])
207+
fields = map(get_name, fields)
205208
quote
206209
$expr
207210
SymbolicUtils.istree(::$name) = true
208211
SymbolicUtils.operation(::$name) = $name
209-
SymbolicUtils.arguments(::$name) = ($(fields...),)
212+
SymbolicUtils.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),))
213+
Base.length(x::$name) = $(length(fields) + 1)
210214
end |> esc
211215
end

test/basics.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using SymbolicUtils: Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments
22
using SymbolicUtils
3+
using IfElse: ifelse
34
using Test
45

56
@testset "@syms" begin
@@ -86,6 +87,7 @@ end
8687

8788
@test symtype(ifelse(true, 4, 5)) == Int
8889
@test symtype(ifelse(a < 0, b, w)) == Union{Real, Complex}
90+
@test SymbolicUtils.promote_symtype(ifelse, Bool, Int, Bool) == Union{Int, Bool}
8991
@test_throws MethodError w < 0
9092
@test isequal(w == 0, Term{Bool}(==, [w, 0]))
9193
end

test/code.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using SparseArrays
88
test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linenums!(b))
99

1010
@testset "Code" begin
11-
@syms a b c d e t x(t) y(t) z(t)
11+
@syms a b c d e p q t x(t) y(t) z(t)
1212
@test toexpr(Assignment(a, b)) == :(a = b)
1313
@test toexpr(a b) == :(a = b)
1414
@test toexpr(a+b) == :($(+)(a, b))
@@ -69,6 +69,22 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
6969
toexpr(Let([a 1, b 2, :arr [1,2]],
7070
MakeArray([a,b,a+b,a/b], :arr)))
7171

72+
test_repr(toexpr(Let([DestructuredArgs([x(t),b,c], :foo) [3,3,[1,4]],
73+
DestructuredArgs([p,q], c)],
74+
x(t)+a+b+c)),
75+
:(let foo = Any[3, 3, [1, 4]],
76+
var"x(t)" = foo[1], b = foo[2], c = foo[3],
77+
p = c[1], q = c[2]
78+
$(+)(a, b, c, var"x(t)")
79+
end))
80+
81+
test_repr(toexpr(Func([DestructuredArgs([a,b],c,inds=[:a, :b])], [],
82+
a + b)),
83+
:(function (c,)
84+
let a = c.a, b = c.b
85+
$(+)(a, b)
86+
end
87+
end))
7288
@syms arr
7389

7490
@test eval(toexpr(Let([a 1, b 2, arr [1,2]],

0 commit comments

Comments
 (0)