Skip to content

Commit 22154d0

Browse files
committed
fix: type instabilities
1 parent 38101d1 commit 22154d0

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,13 @@ function Base.convert(
175175
findoperation(op, operators.unaops)
176176
end
177177

178-
return constructorof(N)(;
179-
op=ind, children=map(x -> convert(N, x, operators; variable_names), args)
180-
)
178+
if length(args) == 2
179+
children = map(x -> convert(N, x, operators; variable_names), (args[1], args[2]))
180+
return constructorof(N)(; op=ind, children)
181+
else
182+
children = map(x -> convert(N, x, operators; variable_names), (only(args),))
183+
return constructorof(N)(; op=ind, children)
184+
end
181185
end
182186

183187
_node_type(::Type{<:AbstractExpression{T,N}}) where {T,N<:AbstractExpressionNode} = N

src/Evaluate.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module EvaluateModule
22

33
using DispatchDoctor: @stable, @unstable
44

5-
import ..NodeModule: AbstractExpressionNode, constructorof, max_degree, children
5+
import ..NodeModule:
6+
AbstractExpressionNode, constructorof, max_degree, children, with_type_parameters
67
import ..StringsModule: string_tree
78
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
89
import ..UtilsModule: fill_similar, counttuple, ResultOk
@@ -244,7 +245,7 @@ function eval_tree_array(
244245
kws...,
245246
) where {T1,T2}
246247
T = promote_type(T1, T2)
247-
tree = convert(constructorof(typeof(tree)){T}, tree)
248+
tree = convert(with_type_parameters(typeof(tree), T), tree)
248249
cX = Base.Fix1(convert, T).(cX)
249250
return eval_tree_array(tree, cX, operators; kws...)
250251
end

src/ParametricExpression.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using ..ChainRulesModule: NodeTangent
1212
import ..NodeModule:
1313
constructorof,
1414
with_type_parameters,
15+
with_max_degree,
1516
max_degree,
1617
preserve_sharing,
1718
leaf_copy,
@@ -121,6 +122,9 @@ end
121122
function with_type_parameters(::Type{N}, ::Type{T}) where {N<:ParametricNode,T}
122123
return ParametricNode{T,max_degree(N)}
123124
end
125+
function with_max_degree(::Type{N}, ::Val{D}) where {T,N<:ParametricNode{T},D}
126+
return ParametricNode{T,D}
127+
end
124128
@unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode{T,2} where {T}
125129
function default_node_type(::Type{N}) where {T,N<:ParametricExpression{T}}
126130
return ParametricNode{T,max_degree(N)}

0 commit comments

Comments
 (0)