Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <[email protected]>"]
version = "2.3.0"
version = "2.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
23 changes: 19 additions & 4 deletions src/Parse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using ..ExpressionModule:
get_operators,
get_variable_names,
node_type
using ..ExpressionAlgebraModule: declare_operator_alias

"""
@parse_expression(expr; operators, variable_names, node_type=Node, evaluate_on=[])
Expand Down Expand Up @@ -331,8 +332,13 @@ end
kws...,
)::N where {F<:Function,N<:AbstractExpressionNode,E<:AbstractExpression}
degree = length(args) - 1
if degree <= length(operators.ops) && func ∈ operators[degree]
op_idx = findfirst(==(func), operators[degree])
if degree <= length(operators.ops) && (
op_idx = findfirst(
op -> op == func || declare_operator_alias(op, Val(degree)) == func,
operators[degree],
);
!isnothing(op_idx)
)
return N(;
op=op_idx::Int,
children=map(
Expand All @@ -342,8 +348,17 @@ end
(args[2:end]...,),
),
)
elseif degree > 2 && func ∈ (+, -, *) && func ∈ operators[2]
op_idx = findfirst(==(func), operators[2])::Int
end

# Handle chaining for +, -, * operators
if degree > 2 &&
func ∈ (+, -, *) &&
(
op_idx = findfirst(
op -> op == func || declare_operator_alias(op, Val(2)) == func, operators[2]
);
!isnothing(op_idx)
)
inner = N(;
op=op_idx::Int,
children=(
Expand Down
34 changes: 34 additions & 0 deletions test/test_parse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,40 @@
end
end

@testitem "Parse with operator aliases" begin
using DynamicExpressions
using DynamicExpressions: DynamicExpressions as DE
using Test

## UNARY
safe_sqrt(x) = x < 0 ? convert(typeof(x), NaN) : sqrt(x)
DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt

operators = OperatorEnum(1 => [safe_sqrt, sin, cos], 2 => [+, -, *, /])

ex = parse_expression(
"sqrt(x) + sin(y)"; operators=operators, variable_names=["x", "y"]
)

@test typeof(ex) <: Expression
@test ex.tree.op == 1
@test ex.tree.children[1].x.op == 1
@test ex.tree.children[2].x.op == 2

## BINARY
safe_pow(x, y) = x < 0 && y != round(y) ? NaN : x^y
DE.declare_operator_alias(::typeof(safe_pow), ::Val{2}) = ^

operators = OperatorEnum(1 => [sin], 2 => [+, -, safe_pow, *])
ex = parse_expression("x^2 + sin(y)"; operators=operators, variable_names=["x", "y"])

@test typeof(ex) <: Expression
@test ex.tree.op == 1
@test ex.tree.children[1].x.op == 3 # safe_pow
@test ex.tree.children[1].x.children[2].x.val == 2.0
@test ex.tree.children[2].x.op == 1
end

@testitem "Can also parse just a float" begin
using DynamicExpressions
operators = OperatorEnum() # Tests empty operators
Expand Down
Loading