diff --git a/Project.toml b/Project.toml index b9f7e630..347cc690 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicExpressions" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" authors = ["MilesCranmer "] -version = "1.2.0" +version = "1.3.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 7e6079f9..79f33858 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -21,6 +21,7 @@ using DispatchDoctor: @stable, @unstable include("Random.jl") include("Parse.jl") include("ParametricExpression.jl") + include("ReadOnlyNode.jl") include("StructuredExpression.jl") end @@ -92,6 +93,7 @@ import .ExpressionModule: @reexport import .ParseModule: @parse_expression, parse_expression import .ParseModule: parse_leaf @reexport import .ParametricExpressionModule: ParametricExpression, ParametricNode +import .ReadOnlyNodeModule: ReadOnlyNode @reexport import .StructuredExpressionModule: StructuredExpression import .StructuredExpressionModule: AbstractStructuredExpression diff --git a/src/Interfaces.jl b/src/Interfaces.jl index 27336ba1..be0f6020 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -52,6 +52,7 @@ using ..ExpressionModule: with_metadata, default_node_type using ..ParametricExpressionModule: ParametricExpression, ParametricNode +using ..ReadOnlyNodeModule: AbstractReadOnlyNode using ..StructuredExpressionModule: StructuredExpression ############################################################################### @@ -68,7 +69,7 @@ function _check_get_metadata(ex::AbstractExpression) return new_ex == ex && new_ex isa typeof(ex) end function _check_get_tree(ex::AbstractExpression{T,N}) where {T,N} - return get_tree(ex) isa N + return get_tree(ex) isa N || get_tree(ex) isa AbstractReadOnlyNode{T,N} end function _check_get_operators(ex::AbstractExpression) return get_operators(ex) isa AbstractOperatorEnum @@ -134,7 +135,8 @@ function _check_constructorof(ex::AbstractExpression) return constructorof(typeof(ex)) isa Base.Callable end function _check_tree_mapreduce(ex::AbstractExpression{T,N}) where {T,N} - return tree_mapreduce(node -> [node], vcat, ex) isa Vector{N} + return tree_mapreduce(node -> [node], vcat, ex) isa + (Vector{N2} where {N2<:Union{N,AbstractReadOnlyNode{T,N}}}) end #! format: off diff --git a/src/ReadOnlyNode.jl b/src/ReadOnlyNode.jl new file mode 100644 index 00000000..96978fa2 --- /dev/null +++ b/src/ReadOnlyNode.jl @@ -0,0 +1,30 @@ +module ReadOnlyNodeModule + +using ..NodeModule: AbstractExpressionNode, Node +import ..NodeModule: default_allocator, with_type_parameters, constructorof + +abstract type AbstractReadOnlyNode{T,N<:AbstractExpressionNode{T}} <: + AbstractExpressionNode{T} end + +"""A type of expression node that also stores a parameter index""" +struct ReadOnlyNode{T,N} <: AbstractReadOnlyNode{T,N} + _inner::N + + ReadOnlyNode(n::N) where {T,N<:AbstractExpressionNode{T}} = new{T,N}(n) +end +constructorof(::Type{<:ReadOnlyNode}) = ReadOnlyNode +@inline function Base.getproperty(n::AbstractReadOnlyNode, s::Symbol) + out = getproperty(getfield(n, :_inner), s) + if out isa AbstractExpressionNode + return constructorof(typeof(n))(out) + else + return out + end +end +function Base.setproperty!(::AbstractReadOnlyNode, ::Symbol, v) + return error("Cannot set properties on a ReadOnlyNode") +end +Base.propertynames(n::AbstractReadOnlyNode) = propertynames(getfield(n, :_inner)) +Base.copy(n::AbstractReadOnlyNode) = ReadOnlyNode(copy(getfield(n, :_inner))) + +end diff --git a/src/StructuredExpression.jl b/src/StructuredExpression.jl index 0282a63f..28ed6cd6 100644 --- a/src/StructuredExpression.jl +++ b/src/StructuredExpression.jl @@ -19,6 +19,7 @@ import ..ExpressionModule: node_type, get_scalar_constants, set_scalar_constants! +import ..ReadOnlyNodeModule: ReadOnlyNode abstract type AbstractStructuredExpression{ T,F<:Function,N<:AbstractExpressionNode{T},E<:AbstractExpression{T,N},D<:NamedTuple @@ -129,7 +130,7 @@ function get_metadata(e::AbstractStructuredExpression) return e.metadata end function get_tree(e::AbstractStructuredExpression) - return get_tree(get_metadata(e).structure(get_contents(e))) + return ReadOnlyNode(get_tree(get_metadata(e).structure(get_contents(e)))) end function get_operators( e::AbstractStructuredExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing diff --git a/test/test_readonlynode.jl b/test/test_readonlynode.jl new file mode 100644 index 00000000..2de8077c --- /dev/null +++ b/test/test_readonlynode.jl @@ -0,0 +1,69 @@ +@testitem "ReadOnlyNode construction and access" begin + using DynamicExpressions + using DynamicExpressions: ReadOnlyNode + + inner_node = Node{Float64}(; val=42.0) + readonly_node = ReadOnlyNode(inner_node) + + @test readonly_node isa ReadOnlyNode + @test getfield(readonly_node, :_inner) === inner_node + @test readonly_node.degree == inner_node.degree + @test readonly_node.constant == inner_node.constant + @test readonly_node.val == inner_node.val +end + +@testitem "ReadOnlyNode immutability" begin + using DynamicExpressions + using DynamicExpressions: ReadOnlyNode + + inner_node = Node{Float64}(; val=42.0) + readonly_node = ReadOnlyNode(inner_node) + + @test_throws ErrorException readonly_node.val = 100.0 + @test_throws "Cannot set properties on a ReadOnlyNode" readonly_node.val = 100.0 +end + +@testitem "ReadOnlyNode - accessing children should return ReadOnlyNode" begin + using DynamicExpressions + using DynamicExpressions: ReadOnlyNode + + operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(sin, exp)) + x1 = Node{Float64}(; feature=1) + x2 = Node{Float64}(; feature=2) + tree = 2 * x1 - sin(x2) + readonly_node = ReadOnlyNode(tree) + + @test typeof(readonly_node.l) === typeof(readonly_node) +end + +@testitem "ReadOnlyNode copy" begin + using DynamicExpressions + using DynamicExpressions: ReadOnlyNode + + inner_node = Node{Float64}(; val=42.0) + readonly_node = ReadOnlyNode(inner_node) + copied_node = copy(readonly_node) + + @test copied_node !== readonly_node + @test copied_node == readonly_node +end + +@testitem "StructuredExpression returns ReadOnlyNode" begin + using DynamicExpressions + using DynamicExpressions: ReadOnlyNode + using DynamicExpressions: StructuredExpression + + operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[-, cos, exp]) + variable_names = ["x", "y"] + kws = (; operators, variable_names) + f = parse_expression(:(x * x - cos(2.5f0 * y + -0.5f0)); kws...) + g = parse_expression(:(exp(-(y * y))); kws...) + + structured_expr = StructuredExpression((; f, g); structure=nt -> nt.f + nt.g, kws...) + + tree = get_tree(structured_expr) + @test tree isa ReadOnlyNode + @test string_tree(tree, operators; variable_names) == + "((x * x) - cos((2.5 * y) + -0.5)) + exp(-(y * y))" + @test getfield(tree, :_inner) isa Node +end diff --git a/test/unittest.jl b/test/unittest.jl index 45cb6e2e..f3ff96c8 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -129,3 +129,4 @@ include("test_operator_construction_edgecases.jl") include("test_node_interface.jl") include("test_expression_math.jl") include("test_structured_expression.jl") +include("test_readonlynode.jl")