diff --git a/Project.toml b/Project.toml index de366393..221dc56b 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,7 @@ MacroTools = "0.4, 0.5" Optim = "0.19, 1" PrecompileTools = "1" Reexport = "1" -SymbolicUtils = "0.19, ^1.0.5, 2, 3" +SymbolicUtils = "4" Zygote = "0.7" julia = "1.10" diff --git a/ext/DynamicExpressionsSymbolicUtilsExt.jl b/ext/DynamicExpressionsSymbolicUtilsExt.jl index 30a987d3..666a2312 100644 --- a/ext/DynamicExpressionsSymbolicUtilsExt.jl +++ b/ext/DynamicExpressionsSymbolicUtilsExt.jl @@ -8,19 +8,14 @@ using DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum using DynamicExpressions.UtilsModule: deprecate_varmap using SymbolicUtils +using SymbolicUtils: BasicSymbolic, SymReal, iscall, issym, isconst, unwrap_const, term import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node import DynamicExpressions.ValueInterfaceModule: is_valid -const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}} +const SYMBOLIC_UTILS_TYPES = Union{<:Number,BasicSymbolic} const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /) -@static if isdefined(SymbolicUtils, :iscall) - iscall(x) = SymbolicUtils.iscall(x) -else - iscall(x) = SymbolicUtils.istree(x) -end - macro return_on_false(flag, retval) :( if !$(esc(flag)) @@ -29,7 +24,7 @@ macro return_on_false(flag, retval) ) end -function is_valid(x::SymbolicUtils.Symbolic) +function is_valid(x::BasicSymbolic) return if iscall(x) all(is_valid.([SymbolicUtils.operation(x); SymbolicUtils.arguments(x)])) else @@ -38,6 +33,17 @@ function is_valid(x::SymbolicUtils.Symbolic) end subs_bad(x) = is_valid(x) ? x : Inf +function _unwrap_const_number(expr::BasicSymbolic) + val = unwrap_const(expr) + if val isa Number + return val + elseif val isa AbstractArray{<:Number} && length(val) == 1 + return only(val) + else + error("Unsupported constant type in SymbolicUtils conversion: $(typeof(val))") + end +end + function parse_tree_to_eqs( tree::AbstractExpressionNode{T}, operators::AbstractOperatorEnum, @@ -46,40 +52,62 @@ function parse_tree_to_eqs( if tree.degree == 0 # Return constant if needed tree.constant && return subs_bad(tree.val) - return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)")) + return SymbolicUtils.Sym{SymReal}(Symbol("x$(tree.feature)"); type=Number) end # Collect the next children # TODO: Type instability! children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,) # Get the operation op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op] - # Create an N tuple of Numbers for each argument - dtypes = map(x -> Number, 1:(tree.degree)) + + # For custom operators, SymbolicUtils can't represent the Julia function directly. + # When `index_functions=true`, represent the operator by its name as an uninterpreted + # SymbolicUtils function symbol so we can round-trip back to a DynamicExpressions node. + if !(op ∈ SUPPORTED_OPS) + if index_functions + dtypes = ntuple(_ -> Number, tree.degree) + op = SymbolicUtils.Sym{SymReal}( + Symbol(op); type=SymbolicUtils.FnType{Tuple{dtypes...},Number,Nothing} + ) + else + error( + "Custom operator '$op' is not supported with SymbolicUtils unless " * + "index_functions=true. Supported operators without indexing: $SUPPORTED_OPS", + ) + end + end + + # Convert children to symbolic form + sym_children = map(x -> parse_tree_to_eqs(x, operators, index_functions), children) + + # SymbolicUtils v4 may canonicalize some commutative operations at construction time + # (e.g. `x*x` -> `x^2`, or reordering `a + b`). # - if !(op ∈ SUPPORTED_OPS) && index_functions - op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{dtypes...},Number}}(Symbol(op)) + # For stable round-trips (and to avoid introducing `^` when it's not in the operator set), + # construct commutative ops as explicit terms. + if op === (*) || op === (+) + return subs_bad(term(op, sym_children...)) end - return subs_bad( - op(map(x -> parse_tree_to_eqs(x, operators, index_functions), children)...) - ) + return subs_bad(op(sym_children...)) end -# For operators which are indexed, we need to convert them back -# using the string: -function convert_to_function( - x::SymbolicUtils.Sym{SymbolicUtils.FnType{T,Number}}, operators::AbstractOperatorEnum -) where {T<:Tuple} - degree = length(T.types) - if degree == 1 - ind = findoperation(x.name, operators.unaops) - return operators.unaops[ind] - elseif degree == 2 - ind = findoperation(x.name, operators.binops) - return operators.binops[ind] - else - throw(AssertionError("Function $(String(x.name)) has degree > 2 !")) +function convert_to_function(x::BasicSymbolic, operators::AbstractOperatorEnum) + if issym(x) && SymbolicUtils.symtype(x) <: SymbolicUtils.FnType + signature, _ = SymbolicUtils.fntype_X_Y(SymbolicUtils.symtype(x)) + degree = length(signature.parameters) + name = nameof(x) + if degree == 1 + ind = findoperation(name, operators.unaops) + return operators.unaops[ind] + elseif degree == 2 + ind = findoperation(name, operators.binops) + return operators.binops[ind] + else + throw(AssertionError("Function $(String(name)) has degree > 2 !")) + end end + return x end # For normal operators, simply return the function itself: @@ -90,7 +118,7 @@ function split_eq( op, args, operators::AbstractOperatorEnum, - ::Type{N}=Node; + (::Type{N})=Node; variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, # Deprecated: varMap=nothing, @@ -119,7 +147,7 @@ function findoperation(op, ops) end function Base.convert( - ::typeof(SymbolicUtils.Symbolic), + ::typeof(BasicSymbolic), tree::Union{AbstractExpression,AbstractExpressionNode}, operators::Union{AbstractOperatorEnum,Nothing}=nothing; variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, @@ -141,14 +169,20 @@ end function Base.convert( ::Type{N}, - expr::SymbolicUtils.Symbolic, + expr::BasicSymbolic, operators::AbstractOperatorEnum; variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, ) where {N<:AbstractExpressionNode} variable_names = deprecate_varmap(variable_names, nothing, :convert) - if !iscall(expr) + # Handle constants (v4 wraps numbers in Const variant) + if isconst(expr) + return constructorof(N)(; val=DEFAULT_NODE_TYPE(_unwrap_const_number(expr))) + end + # Handle symbols (variables) + if issym(expr) + exprname = nameof(expr) if variable_names === nothing - s = String(expr.name) + s = String(exprname) # Verify it is of the format "x{num}": @assert( occursin(r"^x\d+$", s), @@ -156,7 +190,11 @@ function Base.convert( ) return constructorof(N)(s) end - return constructorof(N)(String(expr.name), variable_names) + return constructorof(N)(String(exprname), variable_names) + end + # Handle function calls + if !iscall(expr) + error("Unknown symbolic expression type: $(typeof(expr))") end # First, we remove integer powers: @@ -165,6 +203,14 @@ function Base.convert( expr = y end + # `multiply_powers` may simplify to an atom (e.g. `x^0` -> `1.0`). Re-handle atoms + # before calling `SymbolicUtils.operation`. + if expr isa Number + return convert(N, expr, operators; variable_names) + elseif expr isa BasicSymbolic && !iscall(expr) + return convert(N, expr, operators; variable_names) + end + op = convert_to_function(SymbolicUtils.operation(expr), operators) args = SymbolicUtils.arguments(expr) @@ -185,7 +231,7 @@ _node_type(::Type{E}) where {E<:AbstractExpression} = default_node_type(E) function Base.convert( ::Type{E}, - x::Union{SymbolicUtils.Symbolic,Number}, + x::Union{BasicSymbolic,Number}, operators::AbstractOperatorEnum; variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, kws..., @@ -209,10 +255,9 @@ will generate a symbolic equation in SymbolicUtils.jl format. - `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation. - `variable_names::Union{AbstractVector{<:AbstractString}, Nothing}=nothing`: What variable names to use for each feature. Default is [x1, x2, x3, ...]. -- `index_functions::Bool=false`: Whether to generate special names for the - operators, which then allows one to convert back to a `AbstractExpressionNode` format - using `symbolic_to_node`. - (CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84). +- `index_functions::Bool=false`: Whether to represent custom operators by name as + uninterpreted SymbolicUtils function symbols. This allows round-tripping back to a + `AbstractExpressionNode` using `symbolic_to_node`. """ function node_to_symbolic( tree::AbstractExpressionNode, @@ -231,8 +276,8 @@ function node_to_symbolic( # Create a substitution tuple subs = Dict( [ - SymbolicUtils.Sym{LiteralReal}(Symbol("x$(i)")) => - SymbolicUtils.Sym{LiteralReal}(Symbol(variable_names[i])) for + SymbolicUtils.Sym{SymReal}(Symbol("x$(i)"); type=Number) => + SymbolicUtils.Sym{SymReal}(Symbol(variable_names[i]); type=Number) for i in 1:length(variable_names) ]..., ) @@ -253,9 +298,9 @@ function node_to_symbolic( end function symbolic_to_node( - eqn::SymbolicUtils.Symbolic, + eqn::BasicSymbolic, operators::AbstractOperatorEnum, - ::Type{N}=Node; + (::Type{N})=Node; variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, # Deprecated: varMap=nothing, @@ -268,7 +313,7 @@ function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} return eqn, true end -function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} +function multiply_powers(eqn::BasicSymbolic)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} if !iscall(eqn) return eqn, true end @@ -277,7 +322,7 @@ function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPE end function multiply_powers( - eqn::SymbolicUtils.Symbolic, op::F + eqn::BasicSymbolic, op::F )::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F} args = SymbolicUtils.arguments(eqn) nargs = length(args) @@ -291,15 +336,34 @@ function multiply_powers( @return_on_false complete eqn @return_on_false is_valid(l) eqn n = args[2] - if typeof(n) <: Integer - if n == 1 + # In SymbolicUtils v4, integer constants are wrapped in Const + n_val = if isconst(n) + _unwrap_const_number(n) + elseif typeof(n) <: Integer + n + else + nothing + end + if n_val !== nothing && typeof(n_val) <: Integer + if n_val == 1 return l, true - elseif n == -1 - return 1.0 / l, true - elseif n > 1 - return reduce(*, [l for i in 1:n]), true - elseif n < -1 - return reduce(/, vcat([1], [l for i in 1:abs(n)])), true + elseif n_val == -1 + return term(/, 1.0, l), true + elseif n_val > 1 + # IMPORTANT: use `term(*, ...)` to prevent SymbolicUtils from immediately + # canonicalizing `l*l` back into `l^2`. + out = l + for _ in 2:n_val + out = term(*, out, l) + end + return out, true + elseif n_val < -1 + # Build 1/(l*l*...) using explicit multiplication terms. + denom = l + for _ in 2:abs(n_val) + denom = term(*, denom, l) + end + return term(/, 1.0, denom), true else return 1.0, true end @@ -315,6 +379,11 @@ function multiply_powers( r, complete2 = multiply_powers(args[2]) @return_on_false complete2 eqn @return_on_false is_valid(r) eqn + # SymbolicUtils v4 normalizes `x*x` into `x^2` via the `*` method; preserve + # explicit multiplication terms so we don't introduce `^` during conversion. + if op == * + return term(op, l, r), true + end return op(l, r), true else # return tree_mapreduce(multiply_powers, op, args) @@ -326,7 +395,8 @@ function multiply_powers( end cumulator = out[1][1] for i in 2:size(out, 1) - cumulator = op(cumulator, out[i][1]) + cumulator = + (op == *) ? term(op, cumulator, out[i][1]) : op(cumulator, out[i][1]) @return_on_false is_valid(cumulator) eqn end return cumulator, true diff --git a/src/precompile.jl b/src/precompile.jl index d16bc6b7..ceaa38cb 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -91,11 +91,12 @@ end function test_functions_on_trees(::Type{T}, operators) where {T} local x, c, tree + tree = Node(Float64; val=0.0) num_unaops = length(operators.unaops) num_binops = length(operators.binops) @assert num_unaops > 0 && num_binops > 0 - for T1 in [Float32, Float64] + for T1 in (Float32, Float64) x = Node(T1; feature=1) c = Node(T1; val=T1(1.0)) diff --git a/test/runtests.jl b/test/runtests.jl index c24811a3..cd117022 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,27 +21,36 @@ if "jet" in test_name set_preferences!("DynamicExpressions", "instability_check" => "disable"; force=true) using JET using DynamicExpressions - struct MyIgnoredModule - mod::Module - end - function JET.match_module( - mod::MyIgnoredModule, @nospecialize(report::JET.InferenceErrorReport) - ) - s_mod = string(mod.mod) - any(report.vst) do vst - occursin(s_mod, string(JET.linfomod(vst.linfo))) - end - end + if VERSION >= v"1.10" - JET.test_package( - DynamicExpressions; - target_defined_modules=true, - ignored_modules=( - MyIgnoredModule(DynamicExpressions.NonDifferentiableDeclarationsModule), - ), - ) - # TODO: Hack to get JET to ignore modules - # https://github.com/aviatesk/JET.jl/issues/570#issuecomment-2199167755 + # JET's keyword API has changed across versions. + # Prefer the older (but still supported) configuration first. + try + JET.test_package( + DynamicExpressions; + target_defined_modules=true, + ignored_modules=( + DynamicExpressions.NonDifferentiableDeclarationsModule, + DynamicExpressions.ValueInterfaceModule, + DynamicExpressions.OperatorEnumConstructionModule, + ), + ) + catch err + if err isa MethodError + # Newer JET prefers explicit target_modules. + JET.test_package( + DynamicExpressions; + target_modules=(DynamicExpressions,), + ignored_modules=( + DynamicExpressions.NonDifferentiableDeclarationsModule, + DynamicExpressions.ValueInterfaceModule, + DynamicExpressions.OperatorEnumConstructionModule, + ), + ) + else + rethrow() + end + end end end end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 435ef55e..393dd31e 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -2,5 +2,9 @@ using DynamicExpressions using Aqua if VERSION >= v"1.9" - Aqua.test_all(DynamicExpressions; project_toml_formatting=false) + # Aqua's piracy check relies on some Julia internals that changed in Julia 1.12, + # which can cause a hard error (FieldError: Core.TypeName has no field `mt`). + # We still run piracy checks on older Julia versions. + piracy_ok = VERSION < v"1.12.0-" + Aqua.test_all(DynamicExpressions; project_toml_formatting=false, piracy=piracy_ok) end diff --git a/test/test_chainrules.jl b/test/test_chainrules.jl index 3b721684..61616f8d 100644 --- a/test/test_chainrules.jl +++ b/test/test_chainrules.jl @@ -102,9 +102,9 @@ let @extend_operators operators x1 = Node(Float64; feature=1) - nan_forward = bad_op(x1 + 0.5) - undefined_grad = undefined_grad_op(x1 + 0.5) - nan_grad = bad_grad_op(x1) + nan_forward = @eval bad_op($(x1 + 0.5)) + undefined_grad = @eval undefined_grad_op($(x1 + 0.5)) + nan_grad = @eval bad_grad_op($x1) function eval_tree(X, tree) y, _ = eval_tree_array(tree, X, operators) diff --git a/test/test_simplification.jl b/test/test_simplification.jl index 366d72ba..2c567f56 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -1,7 +1,7 @@ include("test_params.jl") using DynamicExpressions, Test import DynamicExpressions.StringsModule: strip_brackets -import SymbolicUtils: simplify, Symbolic +import SymbolicUtils: simplify, BasicSymbolic import Random: MersenneTwister import Base: ≈ @@ -27,7 +27,7 @@ operators = OperatorEnum(; binary_operators=binary_operators) tree = Node("x1") + Node("x1") # Should simplify to 2*x1: -eqn = convert(Symbolic, tree, operators) +eqn = convert(BasicSymbolic, tree, operators) eqn2 = simplify(eqn) # Should correctly simplify to 2 x1: # (although it might use 2(x1^1)) @@ -44,7 +44,7 @@ tree = convert(Node, eqn2, operators) # Finally, let's try converting a product, and ensure # that SymbolicUtils does not convert it to a power: tree = Node("x1") * Node("x1") -eqn = convert(Symbolic, tree, operators) +eqn = convert(BasicSymbolic, tree, operators) @test repr(eqn) ≈ "x1*x1" # Test converting back: tree_copy = convert(Node, eqn, operators) @@ -70,7 +70,7 @@ tree = ( ) ) # We use `index_functions` to avoid converting the custom operators into the primitives. -eqn = convert(Symbolic, tree, operators; index_functions=true) +eqn = convert(BasicSymbolic, tree, operators; index_functions=true) tree_copy = convert(Node, eqn, operators) tree_copy2 = convert(Node, simplify(eqn), operators) diff --git a/test/test_symbolic_utils.jl b/test/test_symbolic_utils.jl index f85a0cfc..0e634ef6 100644 --- a/test/test_symbolic_utils.jl +++ b/test/test_symbolic_utils.jl @@ -1,34 +1,31 @@ using SymbolicUtils +using StaticArrays using DynamicExpressions using DynamicExpressions: get_operators, get_variable_names using Test include("test_params.jl") -_inv(x) = 1 / x -!(@isdefined safe_pow) && - @eval safe_pow(x::T, y::T) where {T<:Number} = (x < 0 && y != round(y)) ? T(NaN) : x^y -!(@isdefined greater) && @eval greater(x::T, y::T) where {T} = (x > y) ? one(T) : zero(T) - -tree = - let tmp_op = OperatorEnum(; - default_params..., - binary_operators=(+, *, ^, /, greater), - unary_operators=(_inv,), - ) - Node(5, (Node(; val=3.0) * Node(1, Node("x1")))^2.0, Node(; val=-1.2)) - end - +# Test basic conversion with supported operators only operators = OperatorEnum(; - default_params..., - binary_operators=(+, *, safe_pow, /, greater), - unary_operators=(_inv,), + default_params..., binary_operators=(+, *, -, /), unary_operators=(sin, cos, exp) ) -eqn = node_to_symbolic(tree, operators; variable_names=["energy"], index_functions=true) -@test string(eqn) == "greater(safe_pow(3.0_inv(energy), 2.0), -1.2)" +# Build tree: sin(3.0 * x1) + 2.0 +x1_node = Node(; feature=1) +tree = Node(1, Node(; val=3.0) * x1_node) + Node(; val=2.0) + +eqn = node_to_symbolic(tree, operators; variable_names=["energy"]) tree2 = symbolic_to_node(eqn, operators; variable_names=["energy"]) -@test string_tree(tree, operators) == string_tree(tree2, operators) +# SymbolicUtils v4 may reorder commutative operations, so compare by evaluation +X = reshape([1.5, -0.2, 2.0], 1, :) +expected = sin.(3.0 .* X[1, :]) .+ 2.0 + +result1, ok1 = eval_tree_array(tree, X, operators) +result2, ok2 = eval_tree_array(tree2, X, operators) +@test ok1 && ok2 +@test isapprox(result1, result2) +@test isapprox(result2, expected; rtol=0, atol=1.0e-12) # Test variable name conversion with Expression objects let @@ -39,23 +36,124 @@ let variable_names=["x", "y"], ) - # Test conversion to symbolic form preserves variable names - eqn = convert(SymbolicUtils.Symbolic, ex) - @test string(eqn) == "sin(x + y)" + # Test conversion to symbolic form round-trips by evaluation. + eqn = convert(SymbolicUtils.BasicSymbolic, ex) + operators_roundtrip = OperatorEnum(; + unary_operators=(sin,), binary_operators=(+, *, -, /) + ) + ex_again = convert(Expression, eqn, operators_roundtrip; variable_names=["x", "y"]) - # Test with different variable names in the expression + X = rand(Float64, 2, 10) .+ 1 + y1, ok1 = eval_tree_array(ex, X) + y2, ok2 = eval_tree_array(ex_again, X) + @test ok1 && ok2 + @test y1 ≈ y2 + + # Test with different variable names in the expression. + # Use a non-symmetric expression so we can detect any variable swapping, + # while still allowing commutative re-ordering inside SymbolicUtils. ex2 = parse_expression( - :(sin(alpha + beta)); + :(sin(alpha + 2 * beta)); binary_operators=[+, *, -, /], unary_operators=[sin], variable_names=["alpha", "beta"], ) - eqn2 = convert(SymbolicUtils.Symbolic, ex2) - @test string(eqn2) == "sin(alpha + beta)" - eqn2 + eqn2 = convert(SymbolicUtils.BasicSymbolic, ex2) - # Test round trip preserves structure and variable names + # Test round trip preserves semantics and variable names. + # SymbolicUtils v4 may reorder commutative operations, so don't require exact `==`. operators = OperatorEnum(; unary_operators=(sin,), binary_operators=(+, *, -, /)) ex2_again = convert(Expression, eqn2, operators; variable_names=["alpha", "beta"]) - @test ex2 == ex2_again + + X = rand(Float64, 2, 10) .+ 1 + y1, _ = eval_tree_array(ex2, X) + y2, _ = eval_tree_array(ex2_again, X) + @test y1 ≈ y2 +end + +# Const scalar-container unwrapping (SymbolicUtils v4 Const can wrap scalar containers) +let + operators = OperatorEnum(; + default_params..., binary_operators=(+, *, -, /), unary_operators=() + ) + expr = SymbolicUtils.BasicSymbolicImpl.Const{SymbolicUtils.SymReal}(SVector(1.0)) + + node = convert(Node, expr, operators) + + X = rand(Float64, 1, 10) .+ 1 + y, ok = eval_tree_array(node, X, operators) + @test ok + @test all(y .== 1.0f0) +end + +# Operators excluding `^`: preserve semantics for `x*x` and integer powers +let + operators = OperatorEnum(; + default_params..., binary_operators=(+, *, -, /), unary_operators=() + ) + + ex = parse_expression( + :(x * x); + binary_operators=[+, *, -, /], + unary_operators=Function[], + variable_names=["x"], + ) + eqn = convert(SymbolicUtils.BasicSymbolic, ex) + ex_again = convert(Expression, eqn, operators; variable_names=["x"]) + + X = rand(Float64, 1, 50) .+ 0.1 + y1, ok1 = eval_tree_array(ex, X) + y2, ok2 = eval_tree_array(ex_again, X) + @test ok1 && ok2 + @test y1 ≈ y2 + + x = SymbolicUtils.Sym{SymbolicUtils.SymReal}(:x1; type=Number) + for n in (0, 1, 2) + expr = SymbolicUtils.term(^, x, n) + node = convert(Node, expr, operators) + + # SU -> Node -> SU -> Node round-trip should not require `^` in the operator set + eqn_rt = node_to_symbolic(node, operators) + node_rt = if eqn_rt isa Number + convert(Node, eqn_rt, operators) + else + symbolic_to_node(eqn_rt, operators) + end + + X = rand(Float64, 1, 40) .+ 0.1 + y, ok = eval_tree_array(node_rt, X, operators) + @test ok + @test y ≈ Float32.(X[1, :] .^ n) + end + + x2 = SymbolicUtils.Sym{SymbolicUtils.SymReal}(:x2; type=Number) + x3 = SymbolicUtils.Sym{SymbolicUtils.SymReal}(:x3; type=Number) + expr = SymbolicUtils.term(*, x, x2, x3) + node = convert(Node, expr, operators) + + X = rand(Float64, 3, 40) .+ 0.1 + y, ok = eval_tree_array(node, X, operators) + @test ok + @test y ≈ Float32.(X[1, :] .* X[2, :] .* X[3, :]) +end + +# Test custom operator round-trip via index_functions +let + myop(x, y) = x + 2y + operators = OperatorEnum(; binary_operators=(+, *, -, /, myop), unary_operators=(sin,)) + + x1 = Node(; feature=1) + x2 = Node(; feature=2) + tree = Node(; op=5, l=x1, r=x2) # myop(x1, x2) + + @test_throws ErrorException node_to_symbolic(tree, operators; index_functions=false) + + eqn = node_to_symbolic(tree, operators; index_functions=true) + tree2 = symbolic_to_node(eqn, operators) + + X = rand(Float64, 2, 50) .+ 1 + y1, ok1 = eval_tree_array(tree, X, operators) + y2, ok2 = eval_tree_array(tree2, X, operators) + @test ok1 && ok2 + @test y1 ≈ y2 end