Skip to content

Commit 59c0878

Browse files
committed
feat: complete node interface for n-arity
1 parent 3905fc8 commit 59c0878

File tree

3 files changed

+38
-27
lines changed

3 files changed

+38
-27
lines changed

src/Interfaces.jl

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using ..NodeModule:
1212
constructorof,
1313
default_allocator,
1414
with_type_parameters,
15+
children,
1516
leaf_copy,
1617
leaf_convert,
1718
leaf_hash,
@@ -248,8 +249,8 @@ function _check_eltype(tree::AbstractExpressionNode{T}) where {T}
248249
end
249250
function _check_with_type_parameters(tree::AbstractExpressionNode{T}) where {T}
250251
N = typeof(tree)
251-
NT = with_type_parameters(Base.typename(N).wrapper, eltype(tree))
252-
return NT == typeof(tree)
252+
Nf16 = with_type_parameters(N, Float16)
253+
return Nf16 <: AbstractExpressionNode{Float16}
253254
end
254255
function _check_default_allocator(tree::AbstractExpressionNode)
255256
N = Base.typename(typeof(tree)).wrapper
@@ -299,35 +300,21 @@ function _check_leaf_equal(tree::AbstractExpressionNode)
299300
return leaf_equal(tree, copy(tree))
300301
end
301302
function _check_branch_copy(tree::AbstractExpressionNode)
302-
if tree.degree == 0
303-
return true
304-
elseif tree.degree == 1
305-
return branch_copy(tree, tree.l) isa typeof(tree)
306-
else
307-
return branch_copy(tree, tree.l, tree.r) isa typeof(tree)
308-
end
303+
tree.degree == 0 && return true
304+
return branch_copy(tree, children(tree, Val(tree.degree))...) isa typeof(tree)
309305
end
310306
function _check_branch_copy_into!(tree::AbstractExpressionNode{T}) where {T}
311-
if tree.degree == 0
312-
return true
313-
end
307+
tree.degree == 0 && return true
314308
new_branch = constructorof(typeof(tree))(; val=zero(T))
315-
if tree.degree == 1
316-
ret = branch_copy_into!(new_branch, tree, copy(tree.l))
317-
return new_branch == tree && ret === new_branch
318-
else
319-
ret = branch_copy_into!(new_branch, tree, copy(tree.l), copy(tree.r))
320-
return new_branch == tree && ret === new_branch
321-
end
309+
ret = branch_copy_into!(
310+
new_branch, tree, map(copy, children(tree, Val(tree.degree)))...
311+
)
312+
return new_branch == tree && ret === new_branch
322313
end
323314
function _check_branch_convert(tree::AbstractExpressionNode)
324-
if tree.degree == 0
325-
return true
326-
elseif tree.degree == 1
327-
return branch_convert(typeof(tree), tree, tree.l) isa typeof(tree)
328-
else
329-
return branch_convert(typeof(tree), tree, tree.l, tree.r) isa typeof(tree)
330-
end
315+
tree.degree == 0 && return true
316+
return branch_convert(typeof(tree), tree, children(tree, Val(tree.degree))...) isa
317+
typeof(tree)
331318
end
332319
function _check_branch_hash(tree::AbstractExpressionNode)
333320
tree.degree == 0 && return true

src/Node.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ end
225225
@inline children(node::AbstractNode) = node.children
226226
@inline function children(node::AbstractNode, ::Val{n}) where {n}
227227
cs = children(node)
228-
return ntuple(i -> cs[i], Val(n))
228+
return ntuple(i -> cs[i], Val(Int(n)))
229229
end
230230

231231
################################################################################

test/test_node_interface.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,27 @@
3333
],
3434
)
3535
end
36+
37+
@testitem "Node interface on n-arity nodes" begin
38+
using DynamicExpressions
39+
using DynamicExpressions: NodeInterface
40+
using Interfaces: Interfaces
41+
42+
for D in (3, 4, 5)
43+
x = [Node{Float64,D}(; feature=i) for i in 1:3]
44+
operator_tuple = ((sin, cos, exp), (+, *, /, -), (fma, clamp), (max, min), ())
45+
operators = OperatorEnum(operator_tuple[1:D])
46+
DynamicExpressions.OperatorEnumConstructionModule.empty_all_globals!()
47+
let tree = Node{Float64,D}(; op=2, children=(x[1], x[2])) # *
48+
if D > 2
49+
fma_idx = 1
50+
tree = Node{Float64,D}(; op=fma_idx, children=(tree, x[1], x[2])) # fma
51+
end
52+
if D > 3
53+
idx_max = 1
54+
tree = Node{Float64,D}(; op=idx_max, children=(tree, x[1], x[2], x[3])) # max
55+
end
56+
@test Interfaces.test(NodeInterface, Node, tree)
57+
end
58+
end
59+
end

0 commit comments

Comments
 (0)