Skip to content

Commit 633cc94

Browse files
committed
feat: get expression algebra working
1 parent 9ccb46a commit 633cc94

File tree

5 files changed

+31
-7
lines changed

5 files changed

+31
-7
lines changed

src/Expression.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
88
using ..UtilsModule: Undefined
99
using ..ChainRulesModule: NodeTangent
1010

11-
import ..NodeModule: copy_node, set_node!, count_nodes, tree_mapreduce, constructorof
11+
import ..NodeModule:
12+
copy_node, set_node!, count_nodes, tree_mapreduce, constructorof, max_degree
1213
import ..NodeUtilsModule:
1314
preserve_sharing,
1415
count_constant_nodes,
@@ -99,9 +100,14 @@ end
99100
return Expression(tree, Metadata(d))
100101
end
101102

103+
has_node_type(::Union{E,Type{E}}) where {N,E<:AbstractExpression{<:Any,N}} = true
104+
has_node_type(::Union{E,Type{E}}) where {E<:AbstractExpression} = false
102105
node_type(::Union{E,Type{E}}) where {N,E<:AbstractExpression{<:Any,N}} = N
106+
function max_degree(::Union{E,Type{E}}) where {E<:AbstractExpression}
107+
return has_node_type(E) ? max_degree(node_type(E)) : max_degree(Node)
108+
end
103109
@unstable default_node_type(_) = Node
104-
default_node_type(::Type{<:AbstractExpression{T}}) where {T} = Node{T}
110+
default_node_type(::Type{N}) where {T,N<:AbstractExpression{T}} = Node{T,max_degree(N)}
105111

106112
########################################################
107113
# Abstract interface ###################################

src/ExpressionAlgebra.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,33 @@ of the expression.
5454
"""
5555
declare_operator_alias(op::F, _) where {F<:Function} = op
5656

57+
allow_chaining(@nospecialize(op)) = false
58+
allow_chaining(::typeof(+)) = true
59+
allow_chaining(::typeof(*)) = true
60+
5761
function apply_operator(op::F, args::Vararg{Any,D}) where {F<:Function,D}
5862
idx = findfirst(e -> e isa AbstractExpression, args)::Int
5963
example_expr = args[idx]
6064
E = typeof(example_expr)
6165
@assert all(e -> !(e isa AbstractExpression) || typeof(e) === E, args)
6266
operators = get_operators(example_expr, nothing)
6367

64-
op_idx = findfirst(==(op), map(Base.Fix2(declare_operator_alias, Val(D)), operators[D]))
68+
op_idx = if length(operators) >= D
69+
findfirst(==(op), map(Base.Fix2(declare_operator_alias, Val(D)), operators[D]))
70+
else
71+
nothing
72+
end
6573
if isnothing(op_idx)
74+
if allow_chaining(op) && D > 2
75+
# These operators might get chained by Julia, so we check
76+
# downward for any matching arity.
77+
inner = apply_operator(op, args[1:(end - 1)]...)
78+
return apply_operator(op, inner, args[end])
79+
end
6680
throw(
6781
MissingOperatorError(
6882
"Operator $op not found in operators for expression type " *
69-
"$(typeof(l)) with $(D)-degree operators $(operators[D])",
83+
"$(E) with $(D)-degree operators $(operators[D])",
7084
),
7185
)
7286
end
@@ -116,7 +130,6 @@ macro declare_expression_operator(op, arity)
116130
continue
117131
end
118132

119-
120133
arglist = [Expr(:(::), syms[i], types[i]) for i in 1:arity]
121134
signature = Expr(:call, op, arglist...)
122135
if any(t -> t == :T || (t isa Expr && t.head == :curly && :T in t.args), types)

src/OperatorEnum.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ Base.copy(op::AbstractOperatorEnum) = op
4242
@unstable @inline function Base.getindex(op::AbstractOperatorEnum, i::Int)
4343
return getfield(op, :ops)[i]
4444
end
45+
@inline function Base.length(op::AbstractOperatorEnum)
46+
return length(getfield(op, :ops))
47+
end
4548
@inline function Base.getproperty(op::AbstractOperatorEnum, k::Symbol)
4649
if k == :unaops
4750
return getfield(op, :ops)[1]

src/ParametricExpression.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ function with_type_parameters(::Type{N}, ::Type{T}) where {N<:ParametricNode,T}
122122
return ParametricNode{T,max_degree(N)}
123123
end
124124
@unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode{T,2} where {T}
125-
default_node_type(::Type{<:ParametricExpression{T}}) where {T} = ParametricNode{T,2}
125+
function default_node_type(::Type{N}) where {T,N<:ParametricExpression{T}}
126+
return ParametricNode{T,max_degree(N)}
127+
end
126128
preserve_sharing(::Union{Type{<:ParametricNode},ParametricNode}) = false # TODO: Change this?
127129
function leaf_copy(t::ParametricNode{T}) where {T}
128130
if t.constant

test/test_buffered_evaluation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ end
147147
result2, ok2 = eval_tree_array(tree, X, operators; eval_options)
148148

149149
# Results should be identical
150-
@test result1 result2
150+
@test isapprox(result1, result2; atol=1e-10)
151151
@test ok1 == ok2
152152
end
153153
end

0 commit comments

Comments
 (0)