Skip to content

Commit 22778bb

Browse files
authored
Merge pull request #406 from JuliaSymbolics/s/no-bindings
WIP: create_bindings in DestructuredArgs
2 parents 191abe3 + be51fe3 commit 22778bb

File tree

3 files changed

+108
-37
lines changed

3 files changed

+108
-37
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicUtils"
22
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
33
authors = ["Shashi Gowda"]
4-
version = "0.18.2"
4+
version = "0.19"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/code.jl

Lines changed: 80 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ import SymbolicUtils: @matchable, Sym, Term, istree, operation, arguments,
1515
##== state management ==##
1616

1717
struct NameState
18-
symbolify::Dict{Any, Symbol}
18+
rewrites::Dict{Any, Any}
1919
end
20-
NameState() = NameState(Dict{Any, Symbol}())
21-
function union_symbolify!(n, ts)
20+
NameState() = NameState(Dict{Any, Any}())
21+
function union_rewrites!(n, ts)
2222
for t in ts
2323
n[t] = Symbol(string(t))
2424
end
@@ -34,7 +34,7 @@ function Base.get(st::LazyState)
3434
s === nothing ? getfield(st, :ref)[] = NameState() : s
3535
end
3636

37-
@inline Base.getproperty(st::LazyState, f::Symbol) = getproperty(get(st), f)
37+
@inline Base.getproperty(st::LazyState, f::Symbol) = f==:symbolify ? getproperty(st, :rewrites) : getproperty(get(st), f)
3838

3939
##========================##
4040

@@ -75,7 +75,10 @@ when `y(t)` is itself the argument of a function rather than `y`.
7575
7676
"""
7777
toexpr(x) = toexpr(x, LazyState())
78-
toexpr(s::Sym, st) = nameof(s)
78+
function toexpr(s::Sym, st)
79+
s′ = substitute_name(s, st)
80+
s′ isa Sym ? nameof(s′) : toexpr(s′, st)
81+
end
7982

8083

8184
@matchable struct Assignment
@@ -102,7 +105,7 @@ toexpr(a::Assignment, st) = :($(toexpr(a.lhs, st)) = $(toexpr(a.rhs, st)))
102105
function_to_expr(op, args, st) = nothing
103106

104107
function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
105-
out = get(st.symbolify, O, nothing)
108+
out = get(st.rewrites, O, nothing)
106109
out === nothing || return out
107110
args = map(Base.Fix2(toexpr, st), arguments(O))
108111
if length(args) >= 3 && symtype(O) <: Number
@@ -135,18 +138,27 @@ function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)
135138
:($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st)))
136139
end
137140

138-
function_to_expr(::Sym, O, st) = get(st.symbolify, O, nothing)
141+
function_to_expr(::Sym, O, st) = get(st.rewrites, O, nothing)
139142

140143
toexpr(O::Expr, st) = O
141144

145+
function substitute_name(O, st)
146+
if (issym(O) || istree(O)) && haskey(st.rewrites, O)
147+
st.rewrites[O]
148+
else
149+
O
150+
end
151+
end
152+
142153
function toexpr(O, st)
154+
O = substitute_name(O, st)
143155
!istree(O) && return O
144156
op = operation(O)
145157
expr′ = function_to_expr(op, O, st)
146158
if expr′ !== nothing
147159
return expr′
148160
else
149-
haskey(st.symbolify, O) && return st.symbolify[O]
161+
!istree(O) && return O
150162
args = arguments(O)
151163
return Expr(:call, toexpr(op, st), map(x->toexpr(x, st), args)...)
152164
end
@@ -158,10 +170,15 @@ end
158170
inds
159171
name
160172
inbounds::Bool
173+
create_bindings::Bool
161174
end
162175

163-
function DestructuredArgs(elems, name=gensym("arg"); inds=eachindex(elems), inbounds=false)
164-
DestructuredArgs(elems, inds, name, inbounds)
176+
function DestructuredArgs(elems, name=nothing; inds=eachindex(elems), inbounds=false, create_bindings=true)
177+
if name === nothing
178+
# I'm sorry if you get a hash collision here lol
179+
name = Symbol("##arg#", hash((elems, inds, inbounds, create_bindings)))
180+
end
181+
DestructuredArgs(elems, inds, name, inbounds, create_bindings)
165182
end
166183

167184
"""
@@ -176,76 +193,105 @@ components. See example in `Func` for more information.
176193
DestructuredArgs
177194

178195
toexpr(x::DestructuredArgs, st) = toexpr(x.name, st)
179-
get_symbolify(args::DestructuredArgs) = ()
180-
function get_symbolify(args::Union{AbstractArray, Tuple})
181-
cflatten(map(get_symbolify, args))
196+
get_rewrites(args::DestructuredArgs) = ()
197+
function get_rewrites(args::Union{AbstractArray, Tuple})
198+
cflatten(map(get_rewrites, args))
182199
end
183-
get_symbolify(x) = istree(x) ? (x,) : ()
200+
get_rewrites(x) = istree(x) ? (x,) : ()
184201
cflatten(x) = Iterators.flatten(x) |> collect
185202

203+
# Used in Symbolics
204+
Base.@deprecate_binding get_symbolify get_rewrites
205+
186206
function get_assignments(d::DestructuredArgs, st)
187207
name = toexpr(d, st)
188208
map(d.inds, d.elems) do i, a
189209
ex = (i isa Symbol ? :($name.$i) : :($name[$i]))
190-
ex = d.inbounds ? :(@inbounds($ex)) : ex
210+
ex = d.inbounds && d.create_bindings ? :(@inbounds($ex)) : ex
191211
a ex
192212
end
193213
end
194214

195215
@matchable struct Let
196216
pairs::Vector{Union{Assignment,DestructuredArgs}} # an iterator of pairs, ordered
197217
body
218+
let_block::Bool
198219
end
199220

200221
"""
201-
Let(assignments, body)
222+
Let(assignments, body[, let_block])
202223
203224
A Let block.
204225
205226
- `assignments` is a vector of `Assignment`s
206227
- `body` is the body of the let block
228+
- `let_block` boolean (default=true) -- do not create a let block if false.
207229
"""
208-
Let
230+
Let(assignments, body) = Let(assignments, body, true)
209231

210232
function toexpr(l::Let, st)
211233
if all(x->x isa Assignment && !(x.lhs isa DestructuredArgs), l.pairs)
212234
dargs = l.pairs
213235
else
214-
dargs = map(l.pairs) do x
236+
assignments = []
237+
for x in l.pairs
215238
if x isa DestructuredArgs
216-
get_assignments(x, st)
239+
if x.create_bindings
240+
append!(assignments, get_assignments(x, st))
241+
else
242+
for a in get_assignments(x, st)
243+
st.rewrites[a.lhs] = a.rhs
244+
end
245+
end
217246
elseif x isa Assignment && x.lhs isa DestructuredArgs
218-
[x.lhs.name x.rhs, get_assignments(x.lhs, st)...]
247+
if x.lhs.create_bindings
248+
push!(assignments, x.lhs.name x.rhs)
249+
append!(assignments, get_assignments(x.lhs, st))
250+
else
251+
push!(assignments, x.lhs.name x.rhs)
252+
for a in get_assignments(x.lhs, st)
253+
st.rewrites[a.lhs] = a.rhs
254+
end
255+
end
219256
else
220-
(x,)
257+
push!(assignments, x)
221258
end
222-
end |> cflatten
259+
end
223260
# expand and come back
224-
return toexpr(Let(dargs, l.body), st)
261+
return toexpr(Let(assignments, l.body, l.let_block), st)
225262
end
226263

227-
funkyargs = get_symbolify(map(lhs, dargs))
228-
union_symbolify!(st.symbolify, funkyargs)
264+
funkyargs = get_rewrites(map(lhs, dargs))
265+
union_rewrites!(st.rewrites, funkyargs)
229266

230-
Expr(:let,
231-
Expr(:block, map(p->toexpr(p, st), dargs)...),
232-
toexpr(l.body, st))
267+
bindings = map(p->toexpr(p, st), dargs)
268+
l.let_block ? Expr(:let,
269+
Expr(:block, bindings...),
270+
toexpr(l.body, st)) : Expr(:block,
271+
bindings...,
272+
toexpr(l.body, st))
233273
end
234274

235275
@matchable struct Func
236276
args::Vector
237277
kwargs
238278
body
279+
pre::Vector
239280
end
240281

282+
Func(args, kwargs, body) = Func(args, kwargs, body, [])
283+
241284
"""
242-
Func(args, kwargs, body)
285+
Func(args, kwargs, body[, pre])
243286
244287
A function.
245288
246289
- `args` is a vector of expressions
247290
- `kwargs` is a vector of `Assignment`s
248291
- `body` is the body of the function
292+
- `pre` a vector of expressions to be prepended to the function body,
293+
for example, it could be `[Expr(:meta, :inline), Expr(:meta, :propagate_inbounds)]`
294+
to create an `@inline @propagate_inbounds` function definition.
249295
250296
**Special features in `args`**:
251297
@@ -291,21 +337,23 @@ Func
291337
toexpr_kw(f, st) = Expr(:kw, toexpr(f, st).args...)
292338

293339
function toexpr(f::Func, st)
294-
funkyargs = get_symbolify(vcat(f.args, map(lhs, f.kwargs)))
295-
union_symbolify!(st.symbolify, funkyargs)
340+
funkyargs = get_rewrites(vcat(f.args, map(lhs, f.kwargs)))
341+
union_rewrites!(st.rewrites, funkyargs)
296342
dargs = filter(x->x isa DestructuredArgs, f.args)
297343
if !isempty(dargs)
298-
body = Let(dargs, f.body)
344+
body = Let(dargs, f.body, false)
299345
else
300346
body = f.body
301347
end
302348
if isempty(f.kwargs)
303349
:(function ($(map(x->toexpr(x, st), f.args)...),)
350+
$(f.pre...)
304351
$(toexpr(body, st))
305352
end)
306353
else
307354
:(function ($(map(x->toexpr(x, st), f.args)...),;
308355
$(map(x->toexpr_kw(x, st), f.kwargs)...))
356+
$(f.pre...)
309357
$(toexpr(body, st))
310358
end)
311359
end

test/code.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
2222
@test toexpr(x(t)+y(t)) == :($(+)(x(t), y(t)))
2323
@test toexpr(x(t)+y(t)+x(t+1)) == :($(+)($(+)(x(t), y(t)), x($(+)(1, t))))
2424
s = LazyState()
25-
Code.union_symbolify!(s.symbolify, [x(t), y(t)])
25+
Code.union_rewrites!(s.rewrites, [x(t), y(t)])
2626
@test toexpr(x(t)+y(t)+x(t+1), s) == :($(+)($(+)(var"x(t)", var"y(t)"), x($(+)(1, t))))
2727

2828
ex = :(let a = 3, b = $(+)(1,a)
@@ -43,13 +43,34 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
4343
DestructuredArgs((a, b), :params)], [],
4444
x(t+1) + x(t) + a + b)),
4545
:(function (state, params)
46-
let x = state[1], var"x(t)" = state[2], a = params[1], b = params[2]
46+
begin
47+
x = state[1]
48+
var"x(t)" = state[2]
49+
a = params[1]
50+
b = params[2]
4751
$(+)($(+)($(+)(a, b), var"x(t)"), x($(+)(1, t)))
4852
end
4953
end))
5054

55+
test_repr(toexpr(Func([DestructuredArgs([x, x(t)], :state, create_bindings=false),
56+
DestructuredArgs((a, b), :params, create_bindings=false)], [],
57+
x(t+1) + x(t) + a + b)),
58+
:(function (state, params)
59+
begin
60+
$(+)($(+)($(+)(params[1], params[2]), $getindex(state, 2)), state[1]($(+)(1, t)))
61+
end
62+
end))
63+
64+
65+
test_repr(toexpr(Func([],[],:(rand()), [Expr(:meta, :inline)])),
66+
:(function ()
67+
$(Expr(:meta, :inline))
68+
rand()
69+
end))
70+
5171
ex = toexpr(Func([DestructuredArgs([x, x(t)], :state, inbounds=true)], [], x(t+1) + x(t)))
52-
for e ex.args[2].args[3].args[1].args
72+
ex = Base.remove_linenums!(ex)
73+
for e ex.args[2].args[1].args[1:2]
5374
@test e.args[2].head == :macrocall
5475
end
5576

@@ -89,7 +110,9 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
89110
test_repr(toexpr(Func([DestructuredArgs([a,b],c,inds=[:a, :b])], [],
90111
a + b)),
91112
:(function (c,)
92-
let a = c.a, b = c.b
113+
begin
114+
a = c.a
115+
b = c.b
93116
$(+)(a, b)
94117
end
95118
end))

0 commit comments

Comments
 (0)