Skip to content

Commit 97abbd0

Browse files
committed
feat: make differentiable eval work for n-arity
1 parent 008dfbc commit 97abbd0

File tree

1 file changed

+45
-39
lines changed

1 file changed

+45
-39
lines changed

src/Evaluate.jl

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -781,53 +781,59 @@ function differentiable_eval_tree_array(
781781
end
782782

783783
@generated function _differentiable_eval_tree_array(
784+
tree::AbstractExpressionNode{T1,D}, cX::AbstractMatrix{T}, operators::OperatorEnum
785+
)::ResultOk where {T<:Number,T1,D}
786+
quote
787+
tree.degree == 0 && return deg0_diff_eval(tree, cX, operators)
788+
op_idx = tree.op
789+
deg = tree.degree
790+
Base.Cartesian.@nif(
791+
$D,
792+
i -> i == deg,
793+
i -> dispatch_degn_diff_eval(tree, cX, op_idx, Val(i), operators)
794+
)
795+
end
796+
end
797+
798+
799+
function deg0_diff_eval(
784800
tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, operators::OperatorEnum
785801
)::ResultOk where {T<:Number,T1}
786-
nuna = get_nuna(operators)
787-
nbin = get_nbin(operators)
788-
quote
789-
if tree.degree == 0
790-
if tree.constant
791-
ResultOk(fill_similar(one(T), cX, axes(cX, 2)) .* tree.val, true)
792-
else
793-
ResultOk(cX[tree.feature, :], true)
794-
end
795-
elseif tree.degree == 1
796-
op_idx = tree.op
797-
Base.Cartesian.@nif(
798-
$nuna,
799-
i -> i == op_idx,
800-
i -> deg1_diff_eval(tree, cX, operators.unaops[i], operators)
801-
)
802-
else
803-
op_idx = tree.op
804-
Base.Cartesian.@nif(
805-
$nbin,
806-
i -> i == op_idx,
807-
i -> deg2_diff_eval(tree, cX, operators.binops[i], operators)
808-
)
809-
end
802+
if tree.constant
803+
ResultOk(fill_similar(one(T), cX, axes(cX, 2)) .* tree.val, true)
804+
else
805+
ResultOk(cX[tree.feature, :], true)
810806
end
811807
end
812808

813-
function deg1_diff_eval(
814-
tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum
815-
)::ResultOk where {T<:Number,F,T1}
816-
left = _differentiable_eval_tree_array(tree.l, cX, operators)
817-
!left.ok && return left
818-
out = op.(left.x)
809+
function degn_diff_eval(cumulators::C, op::F) where {C<:Tuple,F}
810+
out = op.(cumulators...)
819811
return ResultOk(out, all(isfinite, out))
820812
end
821813

822-
function deg2_diff_eval(
823-
tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum
824-
)::ResultOk where {T<:Number,F,T1}
825-
left = _differentiable_eval_tree_array(tree.l, cX, operators)
826-
!left.ok && return left
827-
right = _differentiable_eval_tree_array(tree.r, cX, operators)
828-
!right.ok && return right
829-
out = op.(left.x, right.x)
830-
return ResultOk(out, all(isfinite, out))
814+
@generated function dispatch_degn_diff_eval(
815+
tree::AbstractExpressionNode{T1,D},
816+
cX::AbstractMatrix{T},
817+
op_idx::Integer,
818+
::Val{degree},
819+
operators::OperatorEnum{OPS}
820+
) where {T<:Number,T1,D,degree,OPS}
821+
nops = length(OPS.types[degree].types)
822+
quote
823+
cs = get_children(tree, Val($degree))
824+
Base.Cartesian.@nexprs($degree, i -> begin
825+
cumulator_i = let result = _differentiable_eval_tree_array(cs[i], cX, operators)
826+
!result.ok && return result
827+
result.x
828+
end
829+
end)
830+
cumulators = Base.Cartesian.@ntuple($degree, i -> cumulator_i)
831+
Base.Cartesian.@nif(
832+
$nops,
833+
i -> i == op_idx,
834+
i -> degn_diff_eval(cumulators, operators[$degree][i])
835+
)
836+
end
831837
end
832838

833839
"""

0 commit comments

Comments
 (0)