Skip to content

Commit ec26d1c

Browse files
committed
add macros
1 parent e6e851b commit ec26d1c

File tree

3 files changed

+129
-103
lines changed

3 files changed

+129
-103
lines changed

src/optics.jl

Lines changed: 78 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ _constructor(::MaybeConstruct, ::Type{T}) where T = constructorof(T)
135135
struct List end
136136
_constructor(::List, ::Type) = tuple
137137

138-
struct Skip end
139-
_constructor(::Skip, ::Type) = _splat_all
138+
struct Splat end
139+
_constructor(::Splat, ::Type) = _splat_all
140140

141141
_splat_all(args...) = _splat_all(args)
142142
@generated function _splat_all(args::A) where A<:Tuple
@@ -242,7 +242,7 @@ end
242242
abstract type ObjectMap end
243243

244244
OpticStyle(::Type{<:ObjectMap}) = ModifyBased()
245-
modify(f, o, optic::ObjectMap) = mapobject(f, o, optic, Construct, nothing)
245+
modify(f, o, optic::ObjectMap) = mapobject(f, o, optic, Construct)
246246

247247
"""
248248
Properties()
@@ -283,17 +283,17 @@ julia> Accessors.mapobject(x -> x+1, obj)
283283
```
284284
$EXPERIMENTAL
285285
"""
286-
function mapobject(f, obj::O, ::Properties, handler, itr::Nothing) where O
286+
function mapobject(f, obj::O, ::Properties, handler) where O
287287
# TODO move this helper elsewhere?
288288
pnames = propertynames(obj)
289289
if isempty(pnames)
290-
return _maybeskip(handler, obj)
290+
return skip(handler) ? () : obj
291291
else
292-
new_props = map(pnames) do p
292+
ctr = _constructor(handler, O)
293+
args = map(pnames) do p
293294
f(getproperty(obj, p))
294295
end
295-
ctr = _constructor(handler, O)
296-
return ctr(new_props...)
296+
return ctr(args...)
297297
end
298298
end
299299
function mapobject(f, obj::O, ::Properties, handler, itr::Int) where O
@@ -334,68 +334,50 @@ $EXPERIMENTAL
334334
"""
335335
struct Fields <: ObjectMap end
336336

337-
@generated function mapobject(f, obj::O, ::Fields, handler::H, itr::Nothing) where {O,H,I}
337+
@generated function mapobject(f, obj::O, ::Fields, handler::H) where {O,H,I}
338338
# TODO: This is how Flatten.jl works, but it's not really
339339
# correct use of ConstructionBase as it assumers properties=fields
340340
fnames = fieldnames(O)
341341
ctr = _constructor(H(), O)
342342
if isempty(fnames)
343-
:(return _maybeskip(handler, obj))
343+
skip(H()) ? :(()) : :(obj)
344344
else
345-
prop_args = map(fn -> :(getfield(obj, $(QuoteNode(fn)))), fnames)
346-
prop_exp = Expr(:tuple, prop_args...)
347-
new_prop_exp = Expr(:tuple, map(pa -> :(f($pa)), prop_args)...)
348-
quote
349-
props = $prop_exp
350-
new_props = $new_prop_exp
351-
return $ctr(new_props...)
345+
args = map(fnames) do fn
346+
:(f(getfield(obj, $(QuoteNode(fn)))))
352347
end
348+
args_exp = Expr(:tuple, args...)
349+
return :($ctr($args_exp...))
353350
end
354351
end
355352
@generated function mapobject(f, obj::O, ::Fields, handler::H, itr::Int) where {O,H}
356-
# TODO: This is how Flatten.jl works, but it's not really
357-
# correct use of ConstructionBase as it assumers properties=fields
358353
fnames = fieldnames(O)
359-
ctr = _constructor(H(), O)
360354
if isempty(fnames)
361-
:(return (obj, itr) => Unchanged())
355+
:(obj => Unchanged(), itr)
362356
else
363-
prop_args = map(fn -> :(getfield(obj, $(QuoteNode(fn)))), fnames)
364-
prop_exp = Expr(:tuple, prop_args...)
365-
### Unrolled iterating function appliation (it will compile away) ####
366-
# Each function call also updates the iterator value in local scoope with
367-
# the return value from the function. But it only actually inserts the
368-
# value into the parent tuple.
369-
val_exps = map(prop_args) do pa
370-
:(((val, itr), change) = f($pa, itr); val => change)
371-
end
372-
new_prop_exp = Expr(:tuple, val_exps...)
373-
quote
374-
props = $prop_exp
375-
new_props = $new_prop_exp
376-
new_props, change = _splitchanged(new_props)
377-
# Don't construct when we don't absolutely have to.
378-
# `constructorof` may not be defined for an object.
379-
if change isa Changed
380-
return ($ctr(new_props...), itr) => change
381-
else
382-
return (obj, itr) => change
383-
end
357+
### Unrolled iterating function appliation ####
358+
# Each function call updates the iterator value in
359+
# local scoope with its return value
360+
args = map(fnames) do fn
361+
:((val, itr) = f(getfield(obj, $(QuoteNode(fn))), itr); val)
384362
end
363+
args_exp = Expr(:tuple, args...)
364+
return :(_maybeconstruct(obj, $args_exp, handler), itr)
385365
end
386366
end
387367

388-
_splitchanged(props) = map(first, props), _findchanged(map(last, props))
389-
390-
_findchanged(::Tuple{Changed,Vararg}) = Changed()
391-
_findchanged(cs::Tuple) = _findchanged(Base.tail(cs))
392-
_findchanged(::Tuple{}) = Unchanged()
393-
394-
_maybeitr(x, ::Nothing) = x
395-
_maybeitr(x, itr) = x, itr
368+
# Don't construct when we don't absolutely have to.
369+
# `constructorof` may not be defined for an object.
370+
@generated function _maybeconstruct(obj::O, props::P, handler::H) where {O,P,H}
371+
ctr = _constructor(H(), O)
372+
if Changed in map(last fieldtypes, fieldtypes(P))
373+
:($ctr(map(first, props)...) => Changed())
374+
else
375+
:(obj => Unchanged())
376+
end
377+
end
396378

397-
_maybeskip(::Skip, v) = ()
398-
_maybeskip(x, v) = v
379+
skip(::Splat) = true
380+
skip(x) = false
399381

400382
"""
401383
Recursive(descent_condition, optic)
@@ -433,45 +415,8 @@ function _modify(f, obj, r::Recursive, ::ModifyBased)
433415
end
434416
end
435417

436-
################################################################################
437-
##### Lenses
438-
################################################################################
439-
struct PropertyLens{fieldname} end
440-
441-
function (l::PropertyLens{field})(obj) where {field}
442-
getproperty(obj, field)
443-
end
444-
445-
@inline function set(obj, l::PropertyLens{field}, val) where {field}
446-
patch = (;field => val)
447-
setproperties(obj, patch)
448-
end
449-
450-
struct IndexLens{I <: Tuple}
451-
indices::I
452-
end
453-
454-
Base.@propagate_inbounds function (lens::IndexLens)(obj)
455-
getindex(obj, lens.indices...)
456-
end
457-
Base.@propagate_inbounds function set(obj, lens::IndexLens, val)
458-
setindex(obj, val, lens.indices...)
459-
end
460-
461-
struct DynamicIndexLens{F}
462-
f::F
463-
end
464-
465-
Base.@propagate_inbounds function (lens::DynamicIndexLens)(obj)
466-
return obj[lens.f(obj)...]
467-
end
468-
469-
Base.@propagate_inbounds function set(obj, lens::DynamicIndexLens, val)
470-
return setindex(obj, val, lens.f(obj)...)
471-
end
472-
473418
"""
474-
Query(select, descend)
419+
Query(select, descend, optic)
475420
476421
Query an object recursively, choosing fields when `select`
477422
returns `true`, and descending when `descend`.
@@ -501,8 +446,10 @@ end
501446
Query(select, descend = x -> true) = Query(select, descend, Fields())
502447
Query(; select=Any, descend=x -> true, optic=Fields()) = Query(select, descend, optic)
503448

449+
OpticStyle(::Type{<:Query}) = SetBased()
450+
504451
function (q::Query)(obj)
505-
mapobject(obj, _inner(q.optic), Skip(), nothing) do o
452+
mapobject(obj, _inner(q.optic), Splat()) do o
506453
if q.select_condition(o)
507454
(_getouter(o, q.optic),)
508455
elseif q.descent_condition(o)
@@ -518,11 +465,11 @@ set(obj, q::Query, vals) = _set(obj, q::Query, (vals, 1))[1][1]
518465
function _set(obj, q::Query, (vals, itr))
519466
mapobject(obj, _inner(q.optic), MaybeConstruct(), itr) do o, itr
520467
if q.select_condition(o)
521-
(_setouter(o, q.optic, vals[itr]), itr + 1) => Changed()
468+
_setouter(o, q.optic, vals[itr]) => Changed(), itr + 1
522469
elseif q.descent_condition(o)
523-
_set(o, q, (vals, itr))
470+
_set(o, q, (vals, itr)) # Will be marked as Changed()/Unchanged()
524471
else
525-
(o, itr) => Unchanged()
472+
o => Unchanged(), itr
526473
end
527474
end
528475
end
@@ -535,3 +482,40 @@ _getouter(o, optic::ComposedOptic) = optic.outer(o)
535482
_getouter(o, optic) = o
536483
_setouter(o, optic::ComposedOptic, v) = set(o, optic.outer, v)
537484
_setouter(o, optic, v) = v
485+
486+
################################################################################
487+
##### Lenses
488+
################################################################################
489+
struct PropertyLens{fieldname} end
490+
491+
function (l::PropertyLens{field})(obj) where {field}
492+
getproperty(obj, field)
493+
end
494+
495+
@inline function set(obj, l::PropertyLens{field}, val) where {field}
496+
patch = (;field => val)
497+
setproperties(obj, patch)
498+
end
499+
500+
struct IndexLens{I <: Tuple}
501+
indices::I
502+
end
503+
504+
Base.@propagate_inbounds function (lens::IndexLens)(obj)
505+
getindex(obj, lens.indices...)
506+
end
507+
Base.@propagate_inbounds function set(obj, lens::IndexLens, val)
508+
setindex(obj, val, lens.indices...)
509+
end
510+
511+
struct DynamicIndexLens{F}
512+
f::F
513+
end
514+
515+
Base.@propagate_inbounds function (lens::DynamicIndexLens)(obj)
516+
return obj[lens.f(obj)...]
517+
end
518+
519+
Base.@propagate_inbounds function set(obj, lens::DynamicIndexLens, val)
520+
return setindex(obj, val, lens.f(obj)...)
521+
end

src/sugar.jl

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export @set, @optic, @reset, @modify
1+
export @set, @optic, @reset, @modify, @getall, @setall
22
using MacroTools
33

44
"""
@@ -84,7 +84,6 @@ end
8484
This function can be used to create a customized variant of [`@modify`](@ref).
8585
See also [`opticmacro`](@ref), [`setmacro`](@ref).
8686
"""
87-
8887
function modifymacro(optictransform, f, obj_optic)
8988
f = esc(f)
9089
obj, optic = parse_obj_optic(obj_optic)
@@ -94,6 +93,37 @@ function modifymacro(optictransform, f, obj_optic)
9493
end)
9594
end
9695

96+
"""
97+
@getall f(obj, arg...)
98+
99+
@getall obj isa Number
100+
"""
101+
macro getall(ex)
102+
ex.head == :call || error("@getall must be a function call")
103+
obj = ex.args[2]
104+
var = gensym()
105+
ex.args[2] = var
106+
esc(:(Query($var -> $ex)($obj)))
107+
end
108+
109+
"""
110+
@setall f(obj, arg...) = values
111+
112+
113+
"""
114+
macro setall(ex)
115+
ex.head == :(=) || error("@setall must contain an = assignment")
116+
func = ex.args[1]
117+
vals = ex.args[2]
118+
func.head == :call || error("@setall must contain a function call")
119+
obj = func.args[2]
120+
var = gensym()
121+
func.args[2] = var
122+
esc(:(set($obj, Query($var -> $func), $vals)))
123+
end
124+
125+
dump(:(a = b))
126+
97127
foldtree(op, init, x) = op(init, x)
98128
foldtree(op, init, ex::Expr) =
99129
op(foldl((acc, x) -> foldtree(op, acc, x), ex.args; init=init), ex)

test/test_queries.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ slowlens = Query(;
1515
optic = (Accessors.@optic _.a) Accessors.Properties()
1616
)
1717

18-
@code_typed lens(obj)
19-
@code_typed slowlens(obj)
18+
lens(obj)
19+
@code_warntype lens(obj)
20+
@code_warntype slowlens(obj)
2021

2122
@code_native lens(obj)
2223
@code_native slowlens(obj)
@@ -28,26 +29,37 @@ println("get")
2829

2930
missings_obj = (a=missing, b=1, c=(d=missing, e=(f=missing, g=2)))
3031
@test Query(ismissing)(missings_obj) === (missing, missing, missing)
32+
@btime Query(ismissing)($missings_obj) === (missing, missing, missing)
3133

3234
println("set")
3335
# Need a wrapper so we don't have to pass in the starting iterator
34-
@btime Accessors.set($obj, $lens, $vals)
36+
set(obj, lens, vals)
37+
@btime set($obj, $lens, $vals)
3538
@btime Accessors._set($obj, $lens, ($vals, 1))[1]
3639
# @btime Accessors.set($obj, $slowlens, $vals)
37-
Accessors.set(obj, lens, vals)
3840
@test Accessors.set(obj, lens, vals) ==
3941
Accessors.set(obj, lens, vals) ==
4042
(7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((a=2.0,), [1]))
4143

42-
@code_warntype Accessors.set(obj, lens, vals)
44+
@code_warntype set(obj, lens, vals)
45+
@code_native set(obj, lens, vals)
46+
@code_native Accessors._set(obj, lens, (vals, 1))[1]
47+
48+
# using Cthulhu
49+
# using ProfileView
50+
# @profview for i in 1:1000000 Accessors.set(obj, lens, vals) end
51+
# @descend Accessors.set(obj, lens, vals)
4352

4453
println("unstable set")
4554
unstable_lens = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> x isa NamedTuple)
46-
@btime Accessors.set($obj, $unstable_lens, $vals)
55+
@btime set($obj, $unstable_lens, $vals)
4756
# slow_unstable_lens = Accessors.Query(; select=x -> x isa Number && x > 4, optic=Properties())
4857
# @btime Accessors.set($obj, $slow_unstable_lens, $vals))
4958

5059
# Somehow modify compiles away almost completely
5160
@btime modify(x -> 10x, $obj, $lens)
52-
@test modify(x -> 10x, obj, lens) == (7, (a=170.0, b=2.0f0), ("3", 4, 5.0), ((a=60.0,), [1]))
5361

62+
# Macros
63+
@test (@getall missings_obj isa Number) == (1, 2)
64+
expected = (a=missing, b=5, c=(d=missing, e=(f=missing, g=6)))
65+
@test (@setall missings_obj isa Number = (5, 6)) === expected

0 commit comments

Comments
 (0)