Skip to content

Commit 21d2e02

Browse files
authored
support broadcasting (#137)
* support broadcasting in optics * remove _UpdateOp * support @set with broadcasting
1 parent cf53adf commit 21d2e02

File tree

3 files changed

+53
-10
lines changed

3 files changed

+53
-10
lines changed

src/functionlenses.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ end
8989

9090
set(obj, o::Base.Fix1{typeof(map)}, val) = map((ob, v) -> set(ob, o.x, v), obj, val)
9191

92+
set(obj, o::Base.BroadcastFunction, val) = set.(obj, Ref(o.f), val)
93+
set(obj, o::Base.Fix1{<:Base.BroadcastFunction}, val) = set.(obj, Base.Fix1.(Ref(o.f.f), o.x), val)
94+
set(obj, o::Base.Fix2{<:Base.BroadcastFunction}, val) = set.(obj, Base.Fix2.(Ref(o.f.f), o.x), val)
95+
9296
set(obj, o::Base.Fix1{typeof(filter)}, val) = @set obj[findall(o.x, obj)] = val
9397
modify(f, obj, o::Base.Fix1{typeof(filter)}) = @modify(f, obj[findall(o.x, obj)])
9498
delete(obj, o::Base.Fix1{typeof(filter)}) = filter(!o.x, obj)

src/sugar.jl

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,21 @@ function lower_index(collection::Symbol, index, dim)
157157
)
158158
end
159159

160+
_secondarg(_, x) = x
161+
162+
_esc_and_dot_name_to_broadcasted(f) = esc(f)
163+
_esc_and_dot_name_to_broadcasted(f::Symbol) =
164+
if f == :.
165+
# eg, in @set a[:] .= 1
166+
# the returned function will be called as func(a, 1)
167+
:(Base.BroadcastFunction($_secondarg))
168+
elseif startswith(string(f), '.')
169+
# eg, in @set a[:] .+= 1 or @optic _ .+ 1
170+
:(Base.BroadcastFunction($(esc(Symbol(string(f)[2:end])))))
171+
else
172+
esc(f)
173+
end
174+
160175
function parse_obj_optics(ex)
161176
dollar_exprs = foldtree([], ex) do exs, x
162177
x isa Expr && x.head == :$ ?
@@ -206,9 +221,15 @@ function parse_obj_optics(ex)
206221
obj, frontoptic = parse_obj_optics(front)
207222
optic = :($PropertyLens{$(QuoteNode(property))}())
208223
elseif @capture(ex, f_(front_))
224+
# regular function call
209225
obj, frontoptic = parse_obj_optics(front)
210-
optic = esc(f) # function optic
226+
optic = _esc_and_dot_name_to_broadcasted(f) # broadcasted operators like .- fall here
227+
elseif @capture(ex, f_.(front_))
228+
# broadcasted function call (not operator)
229+
obj, frontoptic = parse_obj_optics(front)
230+
optic = :(Base.BroadcastFunction($(esc(f))))
211231
elseif @capture(ex, f_(args__))
232+
# function call with multiple arguments
212233
args_contain_under = map(args) do arg
213234
foldtree((yes, x) -> yes || x === :_, false, arg)
214235
end
@@ -219,12 +240,13 @@ function parse_obj_optics(ex)
219240
end
220241
length(args) == 2 || error("Only 1- and 2-argument functions are supported")
221242
sum(args_contain_under) == 1 || error("Only a single function argument can be the optic target")
243+
f = _esc_and_dot_name_to_broadcasted(f) # multi-arg broadcasted fall here, no matter if regular function or operator
222244
if args_contain_under[1]
223245
obj, frontoptic = parse_obj_optics(args[1])
224-
optic = :(Base.Fix2($(esc(f)), $(esc(args[2]))))
246+
optic = :(Base.Fix2($f, $(esc(args[2]))))
225247
elseif args_contain_under[2]
226248
obj, frontoptic = parse_obj_optics(args[2])
227-
optic = :(Base.Fix1($(esc(f)), $(esc(args[1]))))
249+
optic = :(Base.Fix1($f, $(esc(args[1]))))
228250
end
229251
else
230252
obj = esc(ex)
@@ -261,12 +283,6 @@ function get_update_op(sym::Symbol)
261283
Symbol(s[1:end-1])
262284
end
263285

264-
struct _UpdateOp{OP,V}
265-
op::OP
266-
val::V
267-
end
268-
(u::_UpdateOp)(x) = u.op(x, u.val)
269-
270286
"""
271287
setmacro(optictransform, ex::Expr; overwrite::Bool=false)
272288
@@ -300,7 +316,7 @@ function setmacro(optictransform, ex::Expr; overwrite::Bool=false)
300316
:($set($obj, ($optictransform)($optic), $val))
301317
else
302318
op = get_update_op(ex.head)
303-
f = :($_UpdateOp($op,$val))
319+
f = :($Base.Fix2($(_esc_and_dot_name_to_broadcasted(op)), $val))
304320
:($modify($f, $obj, ($optictransform)($optic)))
305321
end
306322
return _macro_expression_result(obj, ret; overwrite=overwrite)

test/test_functionlenses.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,29 @@ end
198198
test_getset_laws(diag, [1 2 3; 4 5 6], [1., 2.5], [0, 1])
199199
end
200200

201+
@testset "broadcast" begin
202+
# in optic definiton
203+
@test (@optic exp.(_)) === Base.BroadcastFunction(exp)
204+
@test (@optic .-_) === Base.BroadcastFunction(-)
205+
@test (@optic 1 .+ _) === Base.Fix1(Base.BroadcastFunction(+), 1)
206+
A = [[1,2], [1,2,3], [1,2,3,4]]
207+
@test (@set first.(A) = 10:12) == [[10, 2], [11, 2, 3], [12, 2, 3, 4]]
208+
209+
test_getset_laws((@optic first.(_)), [[1,2], [1,2,3]], [3., 4.], [5, 6])
210+
test_getset_laws((@optic _ .- 1), [1, 2], [3., 4.], [5, 6])
211+
test_getset_laws((@optic 1 .- _), [1, 2], [3., 4.], [5, 6])
212+
test_getset_laws((@optic [10, 20] .* first.(_)), [[1,2], [1,2,3]], [3., 4.], [5, 6])
213+
214+
# in @set update operation
215+
@test (@set A .= 1) == [1, 1, 1]
216+
@test (@set A[1] .= 1) == [[1,1], [1,2,3], [1,2,3,4]]
217+
@test (@set A[1] .= [10, 20]) == [[10,20], [1,2,3], [1,2,3,4]]
218+
219+
# broadcasting in both optic and set, with a user-defined function
220+
@accessor f(x) = x+1
221+
@test (@set f.(first.(A)) .= [10, 20, 30]) == [[9,2], [19,2,3], [29,2,3,4]]
222+
end
223+
201224
@testset "math" begin
202225
@test 2.0 === @set real(1) = 2.0
203226
@test 2.0 + 1im === @set real(1+1im) = 2.0

0 commit comments

Comments
 (0)