Skip to content

Commit 370e2fc

Browse files
committed
working but slow
1 parent 8fb1580 commit 370e2fc

File tree

2 files changed

+83
-83
lines changed

2 files changed

+83
-83
lines changed

src/optics.jl

Lines changed: 64 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ julia> obj = (a=1, b=2); lens=@optic _.a; val = 100;
4242
4343
julia> set(obj, lens, val)
4444
(a = 100, b = 2)
45-
```
46-
See also [`modify`](@ref).
45+
``` See also [`modify`](@ref).
4746
"""
4847
function set end
4948

@@ -346,15 +345,7 @@ Here `f` has signature `f(::Value, ::State) -> Tuple{NewValue, NewState}`.
346345
"""
347346
function modify_stateful end
348347

349-
@inline function modify_stateful(f, (obj, state), optic::Properties)
350-
let f=f, obj=obj, state=state
351-
modify_stateful_context((obj, state), optic) do _, fn, pr, st
352-
f(getfield(pr, known(fn)), st)
353-
end
354-
end
355-
end
356-
357-
@generated function modify_stateful_context(f, (obj, state1)::T, optic::Properties) where T
348+
@generated function modify_stateful(f::F, (obj, state)::T, optic::Properties) where {T,F}
358349
_modify_stateful_inner(T)
359350
end
360351

@@ -363,29 +354,29 @@ function _modify_stateful_inner(::Type{<:Tuple{O,S}}) where {O,S}
363354
modifications = []
364355
vals = Expr(:tuple)
365356
fns = fieldnames(O)
366-
local st1 = :state0
367-
local st2 = :state1
368357
for (i, fn) in enumerate(fns)
369358
v = Symbol("val$i")
370-
st1 = Symbol("state$i")
371-
st2 = Symbol("state$(i+1)")
372-
ms = if O <: Tuple
373-
:(($v, $st2) = f(obj, StaticInt{$(QuoteNode(fn))}(), props, $st1))
359+
st = if S <: ContextState
360+
if O <: Tuple
361+
:(ContextState(state.vals, obj, StaticInt{$(QuoteNode(fn))}()))
362+
else
363+
:(ContextState(state.vals, obj, StaticSymbol{$(QuoteNode(fn))}()))
364+
end
374365
else
375-
:(($v, $st2) = f(obj, StaticSymbol{$(QuoteNode(fn))}(), props, $st1))
366+
:state
376367
end
368+
ms = :(($v, state) = f(getfield(props, $(QuoteNode(fn))), $st))
377369
push!(modifications, ms)
378370
push!(vals.args, v)
379371
end
380372
patch = O <: Tuple ? vals : :(NamedTuple{$fns}($vals))
381-
Expr(:block,
382-
:(props = getproperties(obj)),
383-
modifications...,
384-
:(patch = $patch),
385-
:(new_obj = maybesetproperties($st2, obj, patch)),
386-
:(new_state = maybesetstate($st2, obj, patch)),
387-
:(return (setproperties(obj, patch), $st2)),
388-
)
373+
start = :(props = getproperties(obj))
374+
rest = MacroTools.@q begin
375+
patch = $patch
376+
new_obj = maybesetproperties(state, obj, patch)
377+
return (new_obj, state)
378+
end
379+
Expr(:block, start, modifications..., rest)
389380
end
390381

391382
maybesetproperties(state, obj, patch) = setproperties(obj, patch)
@@ -426,15 +417,10 @@ Query(; select=Any, descend=x -> true, optic=Properties()) = Query(select, desce
426417

427418
OpticStyle(::Type{<:AbstractQuery}) = SetBased()
428419

429-
struct Context{Select,Descend,Optic<:Union{ComposedOptic,Properties}} <: AbstractQuery
430-
select_condition::Select
431-
descent_condition::Descend
432-
optic::Optic
433-
end
434-
435-
436-
struct ContextState{V}
420+
struct ContextState{V,O,FN}
437421
vals::V
422+
obj::O
423+
fn::FN
438424
end
439425
struct GetAllState{V}
440426
vals::V
@@ -445,57 +431,69 @@ struct SetAllState{C,V,I}
445431
itr::I
446432
end
447433

448-
pop(x) = first(x), Base.tail(x)
449-
push(x, val) = (x..., val)
450-
push(x::GetAllState, val) = GetAllState(push(x.vals, val))
434+
const GetStates = Union{GetAllState,ContextState}
435+
436+
@inline pop(x) = first(x), Base.tail(x)
437+
@inline push(x, val) = (x..., val)
438+
@inline push(x::GetAllState, val) = GetAllState(push(x.vals, val))
439+
@inline push(x::ContextState, val) = ContextState(push(x.vals, val), nothing, nothing)
451440

452441
(q::Query)(obj) = getall(obj, q)
453442

454-
function getall(obj, q)
443+
getall(obj, q) = _getall(obj, q).vals
444+
function _getall(obj, q::Q) where Q<:Query
455445
initial_state = GetAllState(())
456-
_, final_state = modify_stateful((obj, initial_state), q) do o, s
457-
new_state = push(s, outer(q.optic, o, s))
458-
o, new_state
446+
_, final_state = let q=q
447+
modify_stateful((obj, initial_state), q) do o, s
448+
new_state = push(s, outer(q.optic, o, s))
449+
o, new_state
450+
end
459451
end
460-
return final_state.vals
452+
final_state
461453
end
462454

463-
function setall(obj, q, vals)
455+
function setall(obj, q::Q, vals) where Q<:Query
464456
initial_state = SetAllState(Unchanged(), vals, 1)
465-
final_obj, _ = modify_stateful((obj, initial_state), q) do o, s
466-
new_output = outer(q.optic, o, s)
467-
new_state = SetAllState(Changed(), s.vals, s.itr + 1)
468-
new_output, new_state
457+
final_obj, _ = let obj=obj, q=q, initial_state=initial_state
458+
modify_stateful((obj, initial_state), q) do o, s
459+
new_output = outer(q.optic, o, s)
460+
new_state = SetAllState(Changed(), s.vals, s.itr + 1)
461+
new_output, new_state
462+
end
469463
end
470464
return final_obj
471465
end
472466

473-
function context(f, obj, q)
474-
initial_state = GetAllState(())
475-
_, final_state = modify_stateful_context((obj, initial_state), Properties()) do o, fn, pr, s
476-
new_state = push(s, f(o, known(fn)))
477-
o, new_state
467+
function context(f::F, obj, q::Q) where {F,Q<:Query}
468+
initial_state = ContextState((), nothing, nothing)
469+
_, final_state = let f=f
470+
modify_stateful((obj, initial_state), q) do o, s
471+
new_state = push(s, f(s.obj, known(s.fn)))
472+
o, new_state
473+
end
478474
end
479475
return final_state.vals
480476
end
481477

482478
modify(f, obj, q::Query) = setall(obj, q, map(f, getall(obj, q)))
483479

484-
@inline function modify_stateful(f::F, (obj, state), q::Query) where F
485-
modify_stateful((obj, state), inner(q.optic)) do o, s
486-
if q.select_condition(o)
487-
f(o, s)
488-
elseif q.descent_condition(o)
489-
ds = descent_state(s)
490-
o, s = modify_stateful(f::F, (o, ds), q)
491-
o, merge_state(s, ds)
492-
else
493-
o, s
480+
@inline function modify_stateful(f::F, (obj, state), q::Q) where {F,Q<:Query}
481+
let f=f, q=q
482+
modify_stateful((obj, state), inner(q.optic)) do o, s
483+
if (q::Q).select_condition(o)
484+
(f::F)(o, s)
485+
elseif (q::Q).descent_condition(o)
486+
ds = descent_state(s)
487+
o, ns = modify_stateful(f::F, (o, ds), q::Q)
488+
o, merge_state(ds, ns)
489+
else
490+
o, s
491+
end
494492
end
495493
end
496494
end
497495

498-
maybesetproperties(state::GetAllState, obj, patch) = obj
496+
maybesetproperties(state::GetStates, obj, patch) = obj
499497
maybesetproperties(state::SetAllState, obj, patch) =
500498
maybesetproperties(state.change, state, obj, patch)
501499
maybesetproperties(::Changed, state::SetAllState, obj, patch) = setproperties(obj, patch)
@@ -516,8 +514,8 @@ anychanged(::Changed, ::Changed) = Changed()
516514
inner(optic) = optic
517515
inner(optic::ComposedOptic) = optic.inner
518516

519-
outer(optic, o, state::GetAllState) = o
520-
outer(optic::ComposedOptic, o, state::GetAllState) = optic.outer(o)
517+
outer(optic, o, state::GetStates) = o
518+
outer(optic::ComposedOptic, o, state::GetStates) = optic.outer(o)
521519
outer(optic::ComposedOptic, o, state::SetAllState) = set(o, optic.outer, state.vals[state.itr])
522520
outer(optic, o, state::SetAllState) = state.vals[state.itr]
523521

@@ -532,7 +530,7 @@ function (l::PropertyLens{field})(obj) where {field}
532530
end
533531

534532
@inline function set(obj, l::PropertyLens{field}, val) where {field}
535-
patch = (;field => val)
533+
patch = (; field => val)
536534
setproperties(obj, patch)
537535
end
538536

test/test_queries.jl

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,43 @@
11
using Accessors, Test, BenchmarkTools, Static
22
using Accessors: setall, getall, context
3-
4-
obj = (7, (a=17.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=6.0,), [1,]))
3+
obj = (7, (a=17.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=6.0,)), [1])
54
vals = (1.0, 2.0, 3.0, 4.0)
6-
75
# Fields is the default
86
q = Query(;
97
select=x -> x isa NamedTuple,
108
descend=x -> x isa Tuple,
119
optic = (Accessors.@optic _.a) Accessors.Properties()
1210
# optic = Accessors.Properties()
1311
)
14-
15-
println("getall")
1612
getall(obj, q)
13+
1714
@code_native getall(obj, q)
1815
@code_warntype getall(obj, q)
1916

2017
@benchmark getall($obj, $q)
2118
@test getall(obj, q) == (17.0, 6.0)
2219

20+
# using ProfileView, Cthulhu
21+
# @descend getall(obj, q)
22+
# f(obj, q) = for i in 1:10000000 getall(obj, q) end
23+
# @profview f(obj, q)
24+
2325
missings_obj = (a=missing, b=1, c=(d=missing, e=(f=missing, g=2)))
2426
@test getall(missings_obj, Query(ismissing)) === (missing, missing, missing)
2527
@benchmark getall($missings_obj, Query(ismissing))
2628

27-
println("setall")
2829
# Need a wrapper so we don't have to pass in the starting iterator
2930
setall(obj, q, vals)
3031
@benchmark setall($obj, $q, $vals)
32+
# using ProfileView
33+
# @profview for i in 1:1000000 setall(obj, q, vals) end
3134
@code_native setall(obj, q, vals)
3235
@code_warntype setall(obj, q, vals)
3336

3437
# @btime Accessors.set($obj, $slowlens, $vals)
3538
@test setall(obj, q, vals) ==
36-
(7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=2.0,), [1]))
37-
38-
using Cthulhu
39-
@descend getall(obj, q)
40-
# using ProfileView
41-
# @profview for i in 1:1000000 Accessors.set(obj, lens, vals) end
39+
(7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((x=19, a=2.0,)), [1])
4240

43-
println("unstable set")
4441
unstable_q = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> x isa NamedTuple)
4542
@btime setall($obj, $unstable_q, $vals)
4643
# slow_unstable_lens = Accessors.Query(; select=x -> x isa Number && x > 4, optic=Properties())
@@ -50,10 +47,15 @@ unstable_q = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> x
5047
@btime modify(x -> 10x, $obj, $q)
5148

5249
# Context
53-
obj = (b=2, c=2)
54-
@test context((o, fn) -> fn, obj, q) == (:b, :c)
55-
@test context((o, fn) -> typeof(o), obj, q) == (typeof(obj), typeof(obj))
56-
@btime context((o, fn) -> fn, $obj, $q)
50+
q = Query(;
51+
select=x -> x isa Int,
52+
descend=x -> x isa NamedTuple,
53+
optic = Accessors.Properties()
54+
)
55+
obj2 = (1.0, :a, (b=2, c=2))
56+
@test context((o, fn) -> fn, obj2, q) == (:b, :c)
57+
@test context((o, fn) -> typeof(o), obj2, q) == (typeof(obj2[3]), typeof(obj2[3]))
58+
@btime context((o, fn) -> fn, $obj2, $q)
5759

5860
# Macros
5961
@test (@getall (x for x in missings_obj if x isa Number)) == (1, 2)

0 commit comments

Comments
 (0)