@@ -54,86 +54,83 @@ of the expression.
5454"""
5555declare_operator_alias (op:: F , _) where {F<: Function } = op
5656
57- function apply_operator (op:: F , l:: AbstractExpression ) where {F<: Function }
58- operators = get_operators (l, nothing )
59- op_idx = findfirst (
60- == (op), map (Base. Fix2 (declare_operator_alias, Val (1 )), operators. unaops)
61- )
62- if op_idx === nothing
57+ function apply_operator (op:: F , args:: Vararg{Any,D} ) where {F<: Function ,D}
58+ idx = findfirst (e -> e isa AbstractExpression, args):: Int
59+ example_expr = args[idx]
60+ E = typeof (example_expr)
61+ @assert all (e -> ! (e isa AbstractExpression) || typeof (e) === E, args)
62+ operators = get_operators (example_expr, nothing )
63+
64+ op_idx = findfirst (== (op), map (Base. Fix2 (declare_operator_alias, Val (D)), operators[D]))
65+ if isnothing (op_idx)
6366 throw (
6467 MissingOperatorError (
65- " Operator $op not found in operators for expression type $(typeof (l)) with unary operators $(operators. unaops) " ,
68+ " Operator $op not found in operators for expression type " *
69+ " $(typeof (l)) with $(D) -degree operators $(operators[D]) " ,
6670 ),
6771 )
6872 end
69- return insert_operator_index (op_idx, (l,), l)
70- end
71- function apply_operator (op:: F , l, r) where {F<: Function }
72- (operators, example_expr) = if l isa AbstractExpression && r isa AbstractExpression
73- @assert typeof (r) === typeof (l)
74- (get_operators (l, nothing ), l)
75- elseif l isa AbstractExpression
76- (get_operators (l, nothing ), l)
77- else
78- r:: AbstractExpression
79- (get_operators (r, nothing ), r)
80- end
81- op_idx = findfirst (
82- == (op), map (Base. Fix2 (declare_operator_alias, Val (2 )), operators. binops)
83- )
84- if op_idx === nothing
85- throw (
86- MissingOperatorError (
87- " Operator $op not found in operators for expression type $(typeof (l)) with binary operators $(operators. binops) " ,
88- ),
89- )
90- end
91- return insert_operator_index (op_idx, (l, r), example_expr)
73+ return insert_operator_index (op_idx, args, example_expr)
9274end
9375
9476"""
9577 @declare_expression_operator(op, arity)
9678
9779Declare an operator function for `AbstractExpression` types.
9880
99- This macro generates a method for the given operator `op` that works with
100- `AbstractExpression` arguments. The `arity` parameter specifies whether
101- the operator is unary (1) or binary (2).
102-
103- # Arguments
104- - `op`: The operator to be declared (e.g., `Base.sin`, `Base.:+`).
105- - `arity`: The number of arguments the operator takes (1 for unary, 2 for binary).
81+ This macro generates methods for the given operator `op` that work with
82+ `AbstractExpression` arguments. The `arity` parameter specifies the number
83+ of arguments the operator takes.
10684"""
10785macro declare_expression_operator (op, arity)
108- @assert arity ∈ (1 , 2 )
86+ syms = [Symbol (' x' , i) for i in 1 : arity]
87+ AE = :($ (AbstractExpression))
10988 if arity == 1
11089 return esc (
11190 quote
112- $ op (l :: AbstractExpression ) = $ (apply_operator)($ op, l )
91+ $ op ($ ( only (syms)) :: $ (AE)) = $ (apply_operator)($ op, $ ( only (syms)) )
11392 end ,
11493 )
115- elseif arity == 2
116- return esc (
117- quote
118- function $op (l:: AbstractExpression , r:: AbstractExpression )
119- return $ (apply_operator)($ op, l, r)
120- end
121- function $op (l:: T , r:: AbstractExpression{T} ) where {T}
122- return $ (apply_operator)($ op, l, r)
123- end
124- function $op (l:: AbstractExpression{T} , r:: T ) where {T}
125- return $ (apply_operator)($ op, l, r)
126- end
127- # Convenience methods for Number types
128- function $op (l:: Number , r:: AbstractExpression{T} ) where {T}
129- return $ (apply_operator)($ op, l, r)
130- end
131- function $op (l:: AbstractExpression{T} , r:: Number ) where {T}
132- return $ (apply_operator)($ op, l, r)
133- end
134- end ,
94+ end
95+
96+ wrappers = (AE, :($ (AE){T}), :T , :Number )
97+ methods = Expr (:block )
98+
99+ for types in Iterators. product (ntuple (_ -> wrappers, arity)... )
100+ has_expr = any (
101+ t -> t == AE || (t isa Expr && t. head == :curly && t. args[1 ] == AE), types
102+ )
103+ has_plain_T = any (== (:T ), types)
104+ has_abstract_expr_T = any (
105+ t -> t isa Expr && t. head == :curly && t. args[1 ] == AE && :T in t. args, types
135106 )
107+ has_abstract_expr_plain = any (== (AE), types)
108+ if any ((
109+ ! has_expr,
110+ # ^At least one arg must be an AbstractExpression (avoid type‑piracy)
111+ has_abstract_expr_plain && has_abstract_expr_T,
112+ # ^If a plain `T` appears, ensure an `AbstractExpression{T}` is also present
113+ has_plain_T ⊻ has_abstract_expr_T,
114+ # ^Do not mix bare `AbstractExpression` with `AbstractExpression{T}`
115+ ))
116+ continue
117+ end
118+
119+
120+ arglist = [Expr (:(:: ), syms[i], types[i]) for i in 1 : arity]
121+ signature = Expr (:call , op, arglist... )
122+ if any (t -> t == :T || (t isa Expr && t. head == :curly && :T in t. args), types)
123+ signature = Expr (:where , signature, :(T))
124+ end
125+
126+ body = Expr (:block , :(return $ (apply_operator)($ op, $ (syms... ))))
127+
128+ fn = Expr (:function , signature, body)
129+
130+ push! (methods. args, fn)
136131 end
132+
133+ return esc (methods)
137134end
138135
139136# ! format: off
@@ -159,6 +156,11 @@ for op in (
159156)
160157 @eval @declare_expression_operator Base.$ (op) 2
161158end
159+ for op in (
160+ :* , :+ , :clamp , :max , :min , :fma , :muladd ,
161+ )
162+ @eval @declare_expression_operator Base.$ (op) 3
163+ end
162164# ! format: on
163165
164166end
0 commit comments