Skip to content

Commit 1e672bc

Browse files
committed
fix: some issues with D-degree ParametricNode
1 parent 2a0bd05 commit 1e672bc

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/ParametricExpression.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
88
using ..ExpressionModule: AbstractExpression, Metadata
99
using ..ChainRulesModule: NodeTangent
1010

11-
import ..NodeModule: constructorof, preserve_sharing, leaf_copy, leaf_hash, leaf_equal
11+
import ..NodeModule: constructorof, max_degree, preserve_sharing, leaf_copy, leaf_hash, leaf_equal
1212
import ..NodeUtilsModule:
1313
count_constant_nodes,
1414
index_constant_nodes,
@@ -96,10 +96,10 @@ end
9696
###############################################################################
9797
# Abstract expression node interface ##########################################
9898
###############################################################################
99-
@unstable constructorof(::Type{<:ParametricNode}) = ParametricNode
99+
@unstable constructorof(::Type{N}) where {N<:ParametricNode} = ParametricNode{T,max_degree(N)} where {T}
100100
@unstable constructorof(::Type{<:ParametricExpression}) = ParametricExpression
101-
@unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode
102-
default_node_type(::Type{<:ParametricExpression{T}}) where {T} = ParametricNode{T}
101+
@unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode{T,2} where {T}
102+
default_node_type(::Type{<:ParametricExpression{T}}) where {T} = ParametricNode{T,2}
103103
preserve_sharing(::Union{Type{<:ParametricNode},ParametricNode}) = false # TODO: Change this?
104104
function leaf_copy(t::ParametricNode{T}) where {T}
105105
out = if t.constant

0 commit comments

Comments
 (0)