@@ -781,53 +781,59 @@ function differentiable_eval_tree_array(
781781end
782782
783783@generated function _differentiable_eval_tree_array (
784+ tree:: AbstractExpressionNode{T1,D} , cX:: AbstractMatrix{T} , operators:: OperatorEnum
785+ ):: ResultOk where {T<: Number ,T1,D}
786+ quote
787+ tree. degree == 0 && return deg0_diff_eval (tree, cX, operators)
788+ op_idx = tree. op
789+ deg = tree. degree
790+ Base. Cartesian. @nif (
791+ $ D,
792+ i -> i == deg,
793+ i -> dispatch_degn_diff_eval (tree, cX, op_idx, Val (i), operators)
794+ )
795+ end
796+ end
797+
798+
799+ function deg0_diff_eval (
784800 tree:: AbstractExpressionNode{T1} , cX:: AbstractMatrix{T} , operators:: OperatorEnum
785801):: ResultOk where {T<: Number ,T1}
786- nuna = get_nuna (operators)
787- nbin = get_nbin (operators)
788- quote
789- if tree. degree == 0
790- if tree. constant
791- ResultOk (fill_similar (one (T), cX, axes (cX, 2 )) .* tree. val, true )
792- else
793- ResultOk (cX[tree. feature, :], true )
794- end
795- elseif tree. degree == 1
796- op_idx = tree. op
797- Base. Cartesian. @nif (
798- $ nuna,
799- i -> i == op_idx,
800- i -> deg1_diff_eval (tree, cX, operators. unaops[i], operators)
801- )
802- else
803- op_idx = tree. op
804- Base. Cartesian. @nif (
805- $ nbin,
806- i -> i == op_idx,
807- i -> deg2_diff_eval (tree, cX, operators. binops[i], operators)
808- )
809- end
802+ if tree. constant
803+ ResultOk (fill_similar (one (T), cX, axes (cX, 2 )) .* tree. val, true )
804+ else
805+ ResultOk (cX[tree. feature, :], true )
810806 end
811807end
812808
813- function deg1_diff_eval (
814- tree:: AbstractExpressionNode{T1} , cX:: AbstractMatrix{T} , op:: F , operators:: OperatorEnum
815- ):: ResultOk where {T<: Number ,F,T1}
816- left = _differentiable_eval_tree_array (tree. l, cX, operators)
817- ! left. ok && return left
818- out = op .(left. x)
809+ function degn_diff_eval (cumulators:: C , op:: F ) where {C<: Tuple ,F}
810+ out = op .(cumulators... )
819811 return ResultOk (out, all (isfinite, out))
820812end
821813
822- function deg2_diff_eval (
823- tree:: AbstractExpressionNode{T1} , cX:: AbstractMatrix{T} , op:: F , operators:: OperatorEnum
824- ):: ResultOk where {T<: Number ,F,T1}
825- left = _differentiable_eval_tree_array (tree. l, cX, operators)
826- ! left. ok && return left
827- right = _differentiable_eval_tree_array (tree. r, cX, operators)
828- ! right. ok && return right
829- out = op .(left. x, right. x)
830- return ResultOk (out, all (isfinite, out))
814+ @generated function dispatch_degn_diff_eval (
815+ tree:: AbstractExpressionNode{T1,D} ,
816+ cX:: AbstractMatrix{T} ,
817+ op_idx:: Integer ,
818+ :: Val{degree} ,
819+ operators:: OperatorEnum{OPS}
820+ ) where {T<: Number ,T1,D,degree,OPS}
821+ nops = length (OPS. types[degree]. types)
822+ quote
823+ cs = get_children (tree, Val ($ degree))
824+ Base. Cartesian. @nexprs ($ degree, i -> begin
825+ cumulator_i = let result = _differentiable_eval_tree_array (cs[i], cX, operators)
826+ ! result. ok && return result
827+ result. x
828+ end
829+ end )
830+ cumulators = Base. Cartesian. @ntuple ($ degree, i -> cumulator_i)
831+ Base. Cartesian. @nif (
832+ $ nops,
833+ i -> i == op_idx,
834+ i -> degn_diff_eval (cumulators, operators[$ degree][i])
835+ )
836+ end
831837end
832838
833839"""
0 commit comments