Skip to content

Commit b66b178

Browse files
committed
refactor: eliminate other type instabilities
1 parent b309f52 commit b66b178

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

src/Evaluate.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,7 @@ end
255255
@unstable function get_nops(
256256
::Type{O}, ::Val{degree}
257257
) where {OPS,O<:Union{OperatorEnum{OPS},GenericOperatorEnum{OPS}},degree}
258-
max_degree = counttuple(OPS)
259-
return degree > max_degree ? 0 : counttuple(OPS.types[degree])
258+
return degree > counttuple(OPS) ? 0 : counttuple(OPS.types[degree])
260259
end
261260

262261
function _eval_tree_array(
@@ -345,8 +344,26 @@ end
345344
end
346345
end
347346

347+
# TODO: Hack to fix type instability in some branches that can't be inferred.
348+
# It does this using the other branches, which _can_ be inferred.
349+
function _get_return_type(tree, cX, operators, eval_options)
350+
# public Julia API version of `Core.Compiler.return_type(_eval_tree_array, typeof((tree, cX, operators, eval_options)))`
351+
return eltype([_eval_tree_array(tree, cX, operators, eval_options) for _ in 1:0])
352+
end
353+
348354
# This basically forms an if statement over the operators for the degree.
349-
@generated function inner_dispatch_degn_eval(
355+
function inner_dispatch_degn_eval(
356+
tree::AbstractExpressionNode{T},
357+
cX::AbstractMatrix{T},
358+
::Val{degree},
359+
operators::OperatorEnum,
360+
eval_options::EvalOptions,
361+
) where {T,degree}
362+
return _inner_dispatch_degn_eval(
363+
tree, cX, Val(degree), operators, eval_options
364+
)::(_get_return_type(tree, cX, operators, eval_options))
365+
end
366+
@generated function _inner_dispatch_degn_eval(
350367
tree::AbstractExpressionNode{T},
351368
cX::AbstractMatrix{T},
352369
::Val{degree},
@@ -371,7 +388,7 @@ end
371388
i -> i == op_idx,
372389
i -> degn_eval(
373390
cumulators, get_op(operators, Val($degree), Val(i)), eval_options
374-
),
391+
)
375392
)
376393
end
377394
end

src/Node.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module NodeModule
22

33
using DispatchDoctor: @unstable
44

5-
import ..OperatorEnumModule: AbstractOperatorEnum
65
import ..UtilsModule: deprecate_varmap, Undefined
76

87
const DEFAULT_NODE_TYPE = Float32
@@ -260,14 +259,18 @@ end
260259
Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T
261260
Base.eltype(::AbstractExpressionNode{T}) where {T} = T
262261

262+
#! format: off
263263
# COV_EXCL_START
264-
max_degree(::Type{<:AbstractNode}) = DEFAULT_MAX_DEGREE
265-
max_degree(::Type{<:AbstractNode{D}}) where {D} = D
266-
max_degree(node::AbstractNode) = max_degree(typeof(node))
264+
# These are marked unstable due to issues discussed on
265+
# https://github.com/JuliaLang/julia/issues/55147
266+
@unstable max_degree(::Type{<:AbstractNode}) = DEFAULT_MAX_DEGREE
267+
@unstable max_degree(::Type{<:AbstractNode{D}}) where {D} = D
268+
@unstable max_degree(node::AbstractNode) = max_degree(typeof(node))
267269

268270
has_max_degree(::Type{<:AbstractNode}) = false
269271
has_max_degree(::Type{<:AbstractNode{D}}) where {D} = true
270272
# COV_EXCL_STOP
273+
#! format: on
271274

272275
@unstable function constructorof(::Type{N}) where {N<:Node}
273276
return Node{T,max_degree(N)} where {T}

src/Utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ function deprecate_varmap(variable_names, varMap, func_name)
2828
return variable_names
2929
end
3030

31+
# These are marked unstable due to issues discussed on
32+
# https://github.com/JuliaLang/julia/issues/55147
3133
@unstable counttuple(::Type{<:NTuple{N,Any}}) where {N} = N
3234

3335
"""

0 commit comments

Comments
 (0)