Skip to content

Commit cf92c4b

Browse files
committed
fix: various issues with n-arity parametric node
1 parent aca5395 commit cf92c4b

File tree

2 files changed

+129
-215
lines changed

2 files changed

+129
-215
lines changed

src/ParametricExpression.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,18 +313,23 @@ end
313313

314314
function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
315315
num_params = UInt16(size(ex.metadata.parameters, 1))
316+
tree = get_tree(ex)
317+
_NT = typeof(tree)
318+
D = max_degree(_NT)
319+
NT = with_max_degree(with_type_parameters(Node, T), Val(D))
320+
316321
return tree_mapreduce(
317322
leaf -> if leaf.constant
318-
Node(; val=leaf.val)
323+
NT(; val=leaf.val)
319324
elseif leaf.is_parameter
320-
Node(T; feature=leaf.parameter)
325+
NT(T; feature=leaf.parameter)
321326
else
322-
Node(T; feature=leaf.feature + num_params)
327+
NT(T; feature=leaf.feature + num_params)
323328
end,
324329
branch -> branch.op,
325-
(op, children...) -> Node(; op, children),
326-
get_tree(ex),
327-
Node{T},
330+
(op, children...) -> NT(; op, children),
331+
tree,
332+
NT,
328333
)
329334
end
330335
function CRC.rrule(::typeof(convert), ::Type{Node}, ex::ParametricExpression{T}) where {T}

0 commit comments

Comments
 (0)