Skip to content

Commit 83b5bb8

Browse files
committed
composing
1 parent aa95687 commit 83b5bb8

File tree

2 files changed

+41
-52
lines changed

2 files changed

+41
-52
lines changed

src/optics.jl

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,6 @@ struct ConstructIfChanged{C}
133133
constructor::C
134134
end
135135

136-
_constructor(x, t) = _constructor(x.handler, t)
137-
138136
# TODO what do we call these things?
139137
struct Construct end
140138
_constructor(::Construct, ::Type{T}) where T = constructorof(T)
@@ -251,19 +249,6 @@ function modify(f, obj, w::If)
251249
end
252250
end
253251

254-
struct Select{C}
255-
select_condition::C
256-
end
257-
OpticStyle(::Type{<:If}) = ModifyBased()
258-
259-
function modify(f, obj, w::If)
260-
if w.modify_condition(obj)
261-
(obj,)
262-
else
263-
()
264-
end
265-
end
266-
267252
"""
268253
mapproperties(f, obj)
269254
@@ -280,7 +265,7 @@ julia> Accessors.mapproperties(x -> x+1, obj)
280265
```
281266
$EXPERIMENTAL
282267
"""
283-
function mapproperties(f, obj::O, optic, itr::Nothing=nothing) where O
268+
function mapproperties(f, obj::O, handler=Construct(), itr::Nothing=nothing) where O
284269
# TODO move this helper elsewhere?
285270
pnames = propertynames(obj)
286271
if isempty(pnames)
@@ -290,11 +275,11 @@ function mapproperties(f, obj::O, optic, itr::Nothing=nothing) where O
290275
new_props = map(pnames) do p
291276
f(getproperty(obj, p))
292277
end
293-
ctr = _constructor(optic, O)
278+
ctr = _constructor(handler, O)
294279
return ctr(new_props...)
295280
end
296281
end
297-
function mapproperties(f, obj::O, optic, itr::Int) where O
282+
function mapproperties(f, obj::O, handler, itr::Int) where O
298283
pnames = propertynames(obj)
299284
if isempty(pnames)
300285
return _maybeitr(obj, itr)
@@ -304,7 +289,7 @@ function mapproperties(f, obj::O, optic, itr::Int) where O
304289
val, itr = f(getproperty(obj, p), itr)
305290
(vals..., val), itr
306291
end
307-
ctr = _constructor(optic, O)
292+
ctr = _constructor(handler, O)
308293
return _maybeitr(ctr(new_props...), itr)
309294
end
310295
end
@@ -330,10 +315,7 @@ Based on [`mapproperties`](@ref).
330315
331316
$EXPERIMENTAL
332317
"""
333-
struct Properties{H}
334-
handler::H
335-
end
336-
Properties() = Properties(Construct())
318+
struct Properties end
337319
OpticStyle(::Type{<:Properties}) = ModifyBased()
338320
modify(f, o, ::Properties) = mapproperties(f, o)
339321

@@ -353,7 +335,7 @@ julia> Accessors.mapfields(x -> x+1, obj)
353335
```
354336
$EXPERIMENTAL
355337
"""
356-
@generated function mapfields(f, obj::O, optic, itr::I=nothing) where {O,H,I}
338+
@generated function mapfields(f, obj::O, handler::H=Construct(), itr::I=nothing) where {O,H,I}
357339
# TODO: This is how Flatten.jl works, but it's not really
358340
# correct use of ConstructionBase as it assumers properties=fields
359341
fnames = fieldnames(O)
@@ -424,12 +406,9 @@ Based on [`mapfields`](@ref).
424406
425407
$EXPERIMENTAL
426408
"""
427-
struct Fields{H}
428-
handler::H
429-
end
430-
Fields() = Fields(Construct())
409+
struct Fields end
431410
OpticStyle(::Type{<:Fields}) = ModifyBased()
432-
modify(f, o, optic::Fields) = mapfields(f, o, optic)
411+
modify(f, o, ::Fields) = mapfields(f, o)
433412

434413
"""
435414
Recursive(descent_condition, optic)
@@ -528,18 +507,22 @@ julia> modify(x -> 100x, obj, Recursive(x -> (x isa Tuple), Elements()))
528507
```
529508
$EXPERIMENTAL
530509
"""
531-
struct Query{Select,Descend,Optic}
510+
struct Query{Select,Descend,Optic<:Union{ComposedOptic,Fields,Properties}}
532511
select_condition::Select
533512
descent_condition::Descend
534513
optic::Optic
535514
end
536515
Query(select, descend = x -> true) = Query(select, descend, Fields())
537516
Query(; select=Any, descend=x -> true, optic=Fields()) = Query(select, descend, optic)
538517

518+
_inner(optic::ComposedOptic) = optic.inner
519+
_inner(optic::Fields) = optic
520+
_inner(optic::Properties) = optic
521+
539522
function (q::Query)(obj)
540-
_query(obj, q.optic, Splat(), nothing) do o
523+
_query(obj, _inner(q.optic), Splat(), nothing) do o
541524
if q.select_condition(o)
542-
(o,)
525+
(_getouter(o, q.optic),)
543526
elseif q.descent_condition(o)
544527
q(o)
545528
else
@@ -548,12 +531,15 @@ function (q::Query)(obj)
548531
end
549532
end
550533

534+
_getouter(o, optic::ComposedOptic) = optic.outer(o)
535+
_getouter(o, optic) = o
536+
551537
_set(obj, q::Query, val, ::SetBased) = _setquery(obj, q::Query, (val, 1))[1]
552538

553539
function _setquery(obj, q::Query, (val, itr))
554-
_query(obj, q.optic, Construct(), itr) do o, itr
540+
_query(obj, _inner(q.optic), Construct(), itr) do o, itr
555541
if q.select_condition(o)
556-
val[itr], itr + 1
542+
_setouter(o, q.optic, val[itr]), itr + 1
557543
elseif q.descent_condition(o)
558544
_setquery(o, q, (val, itr))
559545
else
@@ -562,8 +548,10 @@ function _setquery(obj, q::Query, (val, itr))
562548
end
563549
end
564550

551+
_setouter(o, optic::ComposedOptic, v) = set(o, optic.outer, v)
552+
_setouter(o, optic, v) = v
553+
565554
modify(f, obj, q::Query) = set(obj, q, map(f, q(obj)))
566555

567-
_query(f, o, ::Elements, itr) = map(f, o)
568-
_query(f, o, ::Fields, itr) = mapfields(f, o, itr)
569-
_query(f, o, ::Properties, itr) = mapproperties(f, o, itr)
556+
_query(f, o, ::Fields, handler, itr) = mapfields(f, o, handler, itr)
557+
_query(f, o, ::Properties, handler, itr) = mapproperties(f, o, handler, itr)

test/test_queries.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
using Accessors, Test, BenchmarkTools
22

3-
obj = (7, (a=17.0, b=2.0f0), ("3", 4, 5.0), (x=6.0,), [1])
3+
obj = (7, (a=17.0, b=2.0f0), ("3", 4, 5.0), ((a=6.0,), [1]))
44
vals = (1.0, 2.0, 3.0, 4.0)
55

66
# Fields is the default
7-
lens = Query(; select=x -> x isa Float64, descend=x -> x isa NamedTuple)
8-
slowlens = Query(x -> x isa Float64, x -> x isa NamedTuple, Accessors.Properties())
7+
lens = Query(;
8+
select=x -> x isa NamedTuple,
9+
descend=x -> x isa Tuple,
10+
optic = (Accessors.@optic _.a) Accessors.Fields()
11+
)
12+
slowlens = Query(;
13+
select=x -> x isa NamedTuple,
14+
descend=x -> x isa Tuple,
15+
optic = (Accessors.@optic _.a) Accessors.Properties()
16+
)
917

1018
@code_typed lens(obj)
1119
@code_typed slowlens(obj)
@@ -22,12 +30,13 @@ println("set")
2230
# Need a wrapper so we don't have to pass in the starting iterator
2331
@btime Accessors.set($obj, $lens, $vals)
2432
@btime Accessors.set($obj, $slowlens, $vals)
33+
Accessors.set(obj, lens, vals)
2534
@test Accessors.set(obj, lens, vals) ==
2635
Accessors.set(obj, lens, vals) ==
27-
(7, (a=1.0, b=2.0f0), ("3", 4, 5.0), (x=2.0,), [1])
36+
(7, (a=1.0, b=2.0f0), ("3", 4, 5.0), ((a=2.0,), [1]))
2837

29-
using ProfileView
30-
@profview for i in 1:1000000 Accessors.set(obj, lens, vals) end
38+
# using ProfileView
39+
# @profview for i in 1:1000000 Accessors.set(obj, lens, vals) end
3140

3241
println("unstable set")
3342
unstable_lens = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x -> x isa NamedTuple)
@@ -37,12 +46,4 @@ unstable_lens = Accessors.Query(select=x -> x isa Float64 && x > 2, descend=x ->
3746

3847
# Somehow modify compiles away almost completely
3948
@btime modify(x -> 10x, $obj, $lens)
40-
@test modify(x -> 10x, obj, lens) == (7, (a=170.0, b=2.0f0), ("3", 4, 5.0), (x=60.0,), [1])
41-
using ProfileView
42-
@profview for i in 1:1000000 modify(x -> 10x, obj, lens) end
43-
44-
45-
obj = (1,2,3,4,5,6);
46-
47-
@macroexpand
48-
@set obj |> Properties() |> If(iseven)
49+
@test modify(x -> 10x, obj, lens) == (7, (a=170.0, b=2.0f0), ("3", 4, 5.0), ((a=60.0,), [1]))

0 commit comments

Comments
 (0)