Skip to content

Commit 9ccb46a

Browse files
committed
feat: get expression operators working with 3-arg input
1 parent 673146c commit 9ccb46a

File tree

3 files changed

+65
-61
lines changed

3 files changed

+65
-61
lines changed

src/ExpressionAlgebra.jl

Lines changed: 61 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -54,86 +54,83 @@ of the expression.
5454
"""
5555
declare_operator_alias(op::F, _) where {F<:Function} = op
5656

57-
function apply_operator(op::F, l::AbstractExpression) where {F<:Function}
58-
operators = get_operators(l, nothing)
59-
op_idx = findfirst(
60-
==(op), map(Base.Fix2(declare_operator_alias, Val(1)), operators.unaops)
61-
)
62-
if op_idx === nothing
57+
function apply_operator(op::F, args::Vararg{Any,D}) where {F<:Function,D}
58+
idx = findfirst(e -> e isa AbstractExpression, args)::Int
59+
example_expr = args[idx]
60+
E = typeof(example_expr)
61+
@assert all(e -> !(e isa AbstractExpression) || typeof(e) === E, args)
62+
operators = get_operators(example_expr, nothing)
63+
64+
op_idx = findfirst(==(op), map(Base.Fix2(declare_operator_alias, Val(D)), operators[D]))
65+
if isnothing(op_idx)
6366
throw(
6467
MissingOperatorError(
65-
"Operator $op not found in operators for expression type $(typeof(l)) with unary operators $(operators.unaops)",
68+
"Operator $op not found in operators for expression type " *
69+
"$(typeof(l)) with $(D)-degree operators $(operators[D])",
6670
),
6771
)
6872
end
69-
return insert_operator_index(op_idx, (l,), l)
70-
end
71-
function apply_operator(op::F, l, r) where {F<:Function}
72-
(operators, example_expr) = if l isa AbstractExpression && r isa AbstractExpression
73-
@assert typeof(r) === typeof(l)
74-
(get_operators(l, nothing), l)
75-
elseif l isa AbstractExpression
76-
(get_operators(l, nothing), l)
77-
else
78-
r::AbstractExpression
79-
(get_operators(r, nothing), r)
80-
end
81-
op_idx = findfirst(
82-
==(op), map(Base.Fix2(declare_operator_alias, Val(2)), operators.binops)
83-
)
84-
if op_idx === nothing
85-
throw(
86-
MissingOperatorError(
87-
"Operator $op not found in operators for expression type $(typeof(l)) with binary operators $(operators.binops)",
88-
),
89-
)
90-
end
91-
return insert_operator_index(op_idx, (l, r), example_expr)
73+
return insert_operator_index(op_idx, args, example_expr)
9274
end
9375

9476
"""
9577
@declare_expression_operator(op, arity)
9678
9779
Declare an operator function for `AbstractExpression` types.
9880
99-
This macro generates a method for the given operator `op` that works with
100-
`AbstractExpression` arguments. The `arity` parameter specifies whether
101-
the operator is unary (1) or binary (2).
102-
103-
# Arguments
104-
- `op`: The operator to be declared (e.g., `Base.sin`, `Base.:+`).
105-
- `arity`: The number of arguments the operator takes (1 for unary, 2 for binary).
81+
This macro generates methods for the given operator `op` that work with
82+
`AbstractExpression` arguments. The `arity` parameter specifies the number
83+
of arguments the operator takes.
10684
"""
10785
macro declare_expression_operator(op, arity)
108-
@assert arity (1, 2)
86+
syms = [Symbol('x', i) for i in 1:arity]
87+
AE = :($(AbstractExpression))
10988
if arity == 1
11089
return esc(
11190
quote
112-
$op(l::AbstractExpression) = $(apply_operator)($op, l)
91+
$op($(only(syms))::$(AE)) = $(apply_operator)($op, $(only(syms)))
11392
end,
11493
)
115-
elseif arity == 2
116-
return esc(
117-
quote
118-
function $op(l::AbstractExpression, r::AbstractExpression)
119-
return $(apply_operator)($op, l, r)
120-
end
121-
function $op(l::T, r::AbstractExpression{T}) where {T}
122-
return $(apply_operator)($op, l, r)
123-
end
124-
function $op(l::AbstractExpression{T}, r::T) where {T}
125-
return $(apply_operator)($op, l, r)
126-
end
127-
# Convenience methods for Number types
128-
function $op(l::Number, r::AbstractExpression{T}) where {T}
129-
return $(apply_operator)($op, l, r)
130-
end
131-
function $op(l::AbstractExpression{T}, r::Number) where {T}
132-
return $(apply_operator)($op, l, r)
133-
end
134-
end,
94+
end
95+
96+
wrappers = (AE, :($(AE){T}), :T, :Number)
97+
methods = Expr(:block)
98+
99+
for types in Iterators.product(ntuple(_ -> wrappers, arity)...)
100+
has_expr = any(
101+
t -> t == AE || (t isa Expr && t.head == :curly && t.args[1] == AE), types
102+
)
103+
has_plain_T = any(==(:T), types)
104+
has_abstract_expr_T = any(
105+
t -> t isa Expr && t.head == :curly && t.args[1] == AE && :T in t.args, types
135106
)
107+
has_abstract_expr_plain = any(==(AE), types)
108+
if any((
109+
!has_expr,
110+
# ^At least one arg must be an AbstractExpression (avoid type‑piracy)
111+
has_abstract_expr_plain && has_abstract_expr_T,
112+
# ^If a plain `T` appears, ensure an `AbstractExpression{T}` is also present
113+
has_plain_T has_abstract_expr_T,
114+
# ^Do not mix bare `AbstractExpression` with `AbstractExpression{T}`
115+
))
116+
continue
117+
end
118+
119+
120+
arglist = [Expr(:(::), syms[i], types[i]) for i in 1:arity]
121+
signature = Expr(:call, op, arglist...)
122+
if any(t -> t == :T || (t isa Expr && t.head == :curly && :T in t.args), types)
123+
signature = Expr(:where, signature, :(T))
124+
end
125+
126+
body = Expr(:block, :(return $(apply_operator)($op, $(syms...))))
127+
128+
fn = Expr(:function, signature, body)
129+
130+
push!(methods.args, fn)
136131
end
132+
133+
return esc(methods)
137134
end
138135

139136
#! format: off
@@ -159,6 +156,11 @@ for op in (
159156
)
160157
@eval @declare_expression_operator Base.$(op) 2
161158
end
159+
for op in (
160+
:*, :+, :clamp, :max, :min, :fma, :muladd,
161+
)
162+
@eval @declare_expression_operator Base.$(op) 3
163+
end
162164
#! format: on
163165

164166
end

src/OperatorEnum.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module OperatorEnumModule
22

3+
using DispatchDoctor: @unstable
4+
35
abstract type AbstractOperatorEnum end
46

57
"""
@@ -37,7 +39,7 @@ end
3739
Base.copy(op::AbstractOperatorEnum) = op
3840
# TODO: Is this safe? What if a vector is passed here?
3941

40-
@inline function Base.getindex(op::AbstractOperatorEnum, i::Int)
42+
@unstable @inline function Base.getindex(op::AbstractOperatorEnum, i::Int)
4143
return getfield(op, :ops)[i]
4244
end
4345
@inline function Base.getproperty(op::AbstractOperatorEnum, k::Symbol)

src/Strings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function dispatch_op_name(
2424
return collect((pretty ? get_pretty_op_name(op) : get_op_name(op))::String)
2525
end
2626

27-
struct OpNameDispatcher{D,O<:AbstractOperatorEnum} <: Function
27+
struct OpNameDispatcher{D,O<:Union{AbstractOperatorEnum,Nothing}} <: Function
2828
operators::O
2929
pretty::Bool
3030
end

0 commit comments

Comments
 (0)