diff --git a/src/Expression.jl b/src/Expression.jl index 7f2d7dc5..e434b206 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -107,7 +107,7 @@ function max_degree(::Union{E,Type{E}}) where {E<:AbstractExpression} return has_node_type(E) ? max_degree(node_type(E)) : max_degree(Node) end @unstable default_node_type(_) = Node -default_node_type(::Type{N}) where {T,N<:AbstractExpression{T}} = Node{T,max_degree(N)} +@unstable default_node_type(::Type{N}) where {T,N<:AbstractExpression{T}} = Node{T} ######################################################## # Abstract interface ################################### diff --git a/src/Interfaces.jl b/src/Interfaces.jl index 4daa556e..53e5c65d 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -15,6 +15,7 @@ using ..NodeModule: with_type_parameters, with_max_degree, max_degree, + has_max_degree, unsafe_get_children, get_children, leaf_copy, @@ -144,7 +145,8 @@ function _check_default_node(ex::AbstractExpression{T}) where {T} ET = typeof(ex) E = Base.typename(ET).wrapper return default_node_type(E) <: AbstractExpressionNode && - default_node_type(ET) <: AbstractExpressionNode{T} + default_node_type(ET) <: AbstractExpressionNode{T} && + !has_max_degree(default_node_type(ET)) end function _check_constructorof(ex::AbstractExpression) return constructorof(typeof(ex)) isa Base.Callable diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 10327c1e..e17a3cf7 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -119,12 +119,9 @@ end # Abstract expression node interface ########################################## ############################################################################### @unstable constructorof(::Type{<:ParametricExpression}) = ParametricExpression -@unstable function default_node_type(::Type{<:ParametricExpression}) - return with_default_max_degree(ParametricNode) -end -function default_node_type(::Type{N}) where {T,N<:ParametricExpression{T}} - return ParametricNode{T,max_degree(N)} -end +@unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode +@unstable default_node_type(::Type{N}) where {T,N<:ParametricExpression{T}} = + ParametricNode{T} preserve_sharing(::Union{Type{<:ParametricNode},ParametricNode}) = false # COV_EXCL_LINE function leaf_copy(t::ParametricNode{T}) where {T} if t.constant diff --git a/test/test_buffered_evaluation.jl b/test/test_buffered_evaluation.jl index a1b41b9a..8ac74ff9 100644 --- a/test/test_buffered_evaluation.jl +++ b/test/test_buffered_evaluation.jl @@ -131,8 +131,11 @@ end for turbo in (false, true), i in 1:100 # Generate a random tree with varying size (1-10 nodes) - n_nodes = rand(1:10) - tree = gen_random_tree_fixed_size(n_nodes, operators, size(X, 1), Float64, Node) + rng = Random.MersenneTwister(i) + n_nodes = rand(rng, 1:10) + tree = gen_random_tree_fixed_size( + n_nodes, operators, size(X, 1), Float64, Node, rng + ) # Regular evaluation eval_options_no_buffer = EvalOptions(; turbo) @@ -142,12 +145,12 @@ end # Buffer evaluation buffer = Array{Float64}(undef, 2n_nodes, size(X, 2)) - buffer_ref = Ref(rand(1:10)) # Random starting index (will be reset) + buffer_ref = Ref(rand(rng, 1:10)) # Random starting index (will be reset) eval_options = EvalOptions(; turbo, buffer=ArrayBuffer(buffer, buffer_ref)) result2, ok2 = eval_tree_array(tree, X, operators; eval_options) # Results should be identical - @test isapprox(result1, result2; atol=1e-10) + @test isapprox(result1, result2; atol=1e-10) || (!ok1 && !ok2) @test ok1 == ok2 end end