@@ -157,6 +157,21 @@ function lower_index(collection::Symbol, index, dim)
157
157
)
158
158
end
159
159
160
+ _secondarg (_, x) = x
161
+
162
+ _esc_and_dot_name_to_broadcasted (f) = esc (f)
163
+ _esc_and_dot_name_to_broadcasted (f:: Symbol ) =
164
+ if f == :.
165
+ # eg, in @set a[:] .= 1
166
+ # the returned function will be called as func(a, 1)
167
+ :(Base. BroadcastFunction ($ _secondarg))
168
+ elseif startswith (string (f), ' .' )
169
+ # eg, in @set a[:] .+= 1 or @optic _ .+ 1
170
+ :(Base. BroadcastFunction ($ (esc (Symbol (string (f)[2 : end ])))))
171
+ else
172
+ esc (f)
173
+ end
174
+
160
175
function parse_obj_optics (ex)
161
176
dollar_exprs = foldtree ([], ex) do exs, x
162
177
x isa Expr && x. head == :$ ?
@@ -206,9 +221,15 @@ function parse_obj_optics(ex)
206
221
obj, frontoptic = parse_obj_optics (front)
207
222
optic = :($ PropertyLens {$(QuoteNode(property))} ())
208
223
elseif @capture (ex, f_ (front_))
224
+ # regular function call
209
225
obj, frontoptic = parse_obj_optics (front)
210
- optic = esc (f) # function optic
226
+ optic = _esc_and_dot_name_to_broadcasted (f) # broadcasted operators like .- fall here
227
+ elseif @capture (ex, f_ .(front_))
228
+ # broadcasted function call (not operator)
229
+ obj, frontoptic = parse_obj_optics (front)
230
+ optic = :(Base. BroadcastFunction ($ (esc (f))))
211
231
elseif @capture (ex, f_ (args__))
232
+ # function call with multiple arguments
212
233
args_contain_under = map (args) do arg
213
234
foldtree ((yes, x) -> yes || x === :_ , false , arg)
214
235
end
@@ -219,12 +240,13 @@ function parse_obj_optics(ex)
219
240
end
220
241
length (args) == 2 || error (" Only 1- and 2-argument functions are supported" )
221
242
sum (args_contain_under) == 1 || error (" Only a single function argument can be the optic target" )
243
+ f = _esc_and_dot_name_to_broadcasted (f) # multi-arg broadcasted fall here, no matter if regular function or operator
222
244
if args_contain_under[1 ]
223
245
obj, frontoptic = parse_obj_optics (args[1 ])
224
- optic = :(Base. Fix2 ($ ( esc (f)) , $ (esc (args[2 ]))))
246
+ optic = :(Base. Fix2 ($ f , $ (esc (args[2 ]))))
225
247
elseif args_contain_under[2 ]
226
248
obj, frontoptic = parse_obj_optics (args[2 ])
227
- optic = :(Base. Fix1 ($ ( esc (f)) , $ (esc (args[1 ]))))
249
+ optic = :(Base. Fix1 ($ f , $ (esc (args[1 ]))))
228
250
end
229
251
else
230
252
obj = esc (ex)
@@ -261,12 +283,6 @@ function get_update_op(sym::Symbol)
261
283
Symbol (s[1 : end - 1 ])
262
284
end
263
285
264
- struct _UpdateOp{OP,V}
265
- op:: OP
266
- val:: V
267
- end
268
- (u:: _UpdateOp )(x) = u. op (x, u. val)
269
-
270
286
"""
271
287
setmacro(optictransform, ex::Expr; overwrite::Bool=false)
272
288
@@ -300,7 +316,7 @@ function setmacro(optictransform, ex::Expr; overwrite::Bool=false)
300
316
:($ set ($ obj, ($ optictransform)($ optic), $ val))
301
317
else
302
318
op = get_update_op (ex. head)
303
- f = :($ _UpdateOp ( $ op, $ val))
319
+ f = :($ Base . Fix2 ( $ ( _esc_and_dot_name_to_broadcasted (op)), $ val))
304
320
:($ modify ($ f, $ obj, ($ optictransform)($ optic)))
305
321
end
306
322
return _macro_expression_result (obj, ret; overwrite= overwrite)
0 commit comments