|
255 | 255 | @unstable function get_nops( |
256 | 256 | ::Type{O}, ::Val{degree} |
257 | 257 | ) where {OPS,O<:Union{OperatorEnum{OPS},GenericOperatorEnum{OPS}},degree} |
258 | | - max_degree = counttuple(OPS) |
259 | | - return degree > max_degree ? 0 : counttuple(OPS.types[degree]) |
| 258 | + return degree > counttuple(OPS) ? 0 : counttuple(OPS.types[degree]) |
260 | 259 | end |
261 | 260 |
|
262 | 261 | function _eval_tree_array( |
|
345 | 344 | end |
346 | 345 | end |
347 | 346 |
|
| 347 | +# TODO: Hack to fix type instability in some branches that can't be inferred. |
| 348 | +# It does this using the other branches, which _can_ be inferred. |
| 349 | +function _get_return_type(tree, cX, operators, eval_options) |
| 350 | + # public Julia API version of `Core.Compiler.return_type(_eval_tree_array, typeof((tree, cX, operators, eval_options)))` |
| 351 | + return eltype([_eval_tree_array(tree, cX, operators, eval_options) for _ in 1:0]) |
| 352 | +end |
| 353 | + |
348 | 354 | # This basically forms an if statement over the operators for the degree. |
349 | | -@generated function inner_dispatch_degn_eval( |
| 355 | +function inner_dispatch_degn_eval( |
| 356 | + tree::AbstractExpressionNode{T}, |
| 357 | + cX::AbstractMatrix{T}, |
| 358 | + ::Val{degree}, |
| 359 | + operators::OperatorEnum, |
| 360 | + eval_options::EvalOptions, |
| 361 | +) where {T,degree} |
| 362 | + return _inner_dispatch_degn_eval( |
| 363 | + tree, cX, Val(degree), operators, eval_options |
| 364 | + )::(_get_return_type(tree, cX, operators, eval_options)) |
| 365 | +end |
| 366 | +@generated function _inner_dispatch_degn_eval( |
350 | 367 | tree::AbstractExpressionNode{T}, |
351 | 368 | cX::AbstractMatrix{T}, |
352 | 369 | ::Val{degree}, |
|
371 | 388 | i -> i == op_idx, |
372 | 389 | i -> degn_eval( |
373 | 390 | cumulators, get_op(operators, Val($degree), Val(i)), eval_options |
374 | | - ), |
| 391 | + ) |
375 | 392 | ) |
376 | 393 | end |
377 | 394 | end |
|
0 commit comments