795795 end
796796end
797797
798-
799798function deg0_diff_eval (
800799 tree:: AbstractExpressionNode{T1} , cX:: AbstractMatrix{T} , operators:: OperatorEnum
801800):: ResultOk where {T<: Number ,T1}
@@ -816,22 +815,24 @@ end
816815 cX:: AbstractMatrix{T} ,
817816 op_idx:: Integer ,
818817 :: Val{degree} ,
819- operators:: OperatorEnum{OPS}
818+ operators:: OperatorEnum{OPS} ,
820819) where {T<: Number ,T1,D,degree,OPS}
821820 nops = length (OPS. types[degree]. types)
822821 quote
823822 cs = get_children (tree, Val ($ degree))
824- Base. Cartesian. @nexprs ($ degree, i -> begin
825- cumulator_i = let result = _differentiable_eval_tree_array (cs[i], cX, operators)
826- ! result. ok && return result
827- result. x
823+ Base. Cartesian. @nexprs (
824+ $ degree,
825+ i -> begin
826+ cumulator_i =
827+ let result = _differentiable_eval_tree_array (cs[i], cX, operators)
828+ ! result. ok && return result
829+ result. x
830+ end
828831 end
829- end )
832+ )
830833 cumulators = Base. Cartesian. @ntuple ($ degree, i -> cumulator_i)
831834 Base. Cartesian. @nif (
832- $ nops,
833- i -> i == op_idx,
834- i -> degn_diff_eval (cumulators, operators[$ degree][i])
835+ $ nops, i -> i == op_idx, i -> degn_diff_eval (cumulators, operators[$ degree][i])
835836 )
836837 end
837838end
@@ -910,77 +911,89 @@ function eval(current_node)
910911 end
911912end
912913
913- @unstable function _eval_tree_array_generic (
914- tree:: AbstractExpressionNode{T1} ,
914+ @generated function _eval_tree_array_generic (
915+ tree:: AbstractExpressionNode{T1,D } ,
915916 cX:: AbstractArray{T2,N} ,
916917 operators:: GenericOperatorEnum ,
917918 :: Val{throw_errors} ,
918- ) where {T1,T2,N,throw_errors}
919- if tree. degree == 0
920- if tree. constant
921- if N == 1
922- return (tree. val:: T1 ), true
923- else
924- return fill (tree. val:: T1 , size (cX)[2 : N]), true
925- end
919+ ) where {T1,D,T2,N,throw_errors}
920+ quote
921+ tree. degree == 0 && return deg0_eval_generic (tree, cX)
922+ op_idx = tree. op
923+ deg = tree. degree
924+ Base. Cartesian. @nif (
925+ $ D,
926+ i -> i == deg,
927+ i -> dispatch_degn_eval_generic (
928+ tree, cX, op_idx, Val (i), operators, Val (throw_errors)
929+ )
930+ )
931+ end
932+ end
933+
934+ @unstable function deg0_eval_generic (
935+ tree:: AbstractExpressionNode{T1} , cX:: AbstractArray{T2,N}
936+ ) where {T1,T2,N}
937+ if tree. constant
938+ if N == 1
939+ return (tree. val:: T1 ), true
926940 else
927- if N == 1
928- return (cX[tree. feature]), true
929- else
930- return copy (selectdim (cX, 1 , tree. feature)), true
931- end
941+ return fill (tree. val:: T1 , size (cX)[2 : N]), true
932942 end
933- elseif tree. degree == 1
934- return deg1_eval_generic (
935- tree, cX, operators. unaops[tree. op], operators, Val (throw_errors)
936- )
937943 else
938- return deg2_eval_generic (
939- tree, cX, operators. binops[tree. op], operators, Val (throw_errors)
940- )
944+ if N == 1
945+ return (cX[tree. feature]), true
946+ else
947+ return copy (selectdim (cX, 1 , tree. feature)), true
948+ end
941949 end
942950end
943951
944- @unstable function deg1_eval_generic (
945- tree:: AbstractExpressionNode{T1} ,
946- cX:: AbstractArray{T2,N} ,
947- op:: F ,
948- operators:: GenericOperatorEnum ,
949- :: Val{throw_errors} ,
950- ) where {F,T1,T2,N,throw_errors}
951- left, complete = _eval_tree_array_generic (tree. l, cX, operators, Val (throw_errors))
952- ! throw_errors && ! complete && return nothing , false
953- ! throw_errors &&
954- ! hasmethod (op, N == 1 ? Tuple{typeof (left)} : Tuple{eltype (left)}) &&
955- return nothing , false
952+ @unstable function degn_eval_generic (
953+ cumulators:: C , op:: F , :: Val{N} , :: Val{throw_errors}
954+ ) where {C<: Tuple ,F,N,throw_errors}
955+ if ! throw_errors
956+ input_type = N == 1 ? C : Tuple{map (eltype, cumulators)... }
957+ ! hasmethod (op, input_type) && return nothing , false
958+ end
956959 if N == 1
957- return op (left ), true
960+ return op (cumulators ... ), true
958961 else
959- return op .(left ), true
962+ return op .(cumulators ... ), true
960963 end
961964end
962965
963- @unstable function deg2_eval_generic (
966+ @generated function dispatch_degn_eval_generic (
964967 tree:: AbstractExpressionNode{T1} ,
965968 cX:: AbstractArray{T2,N} ,
966- op:: F ,
967- operators:: GenericOperatorEnum ,
969+ op_idx:: Integer ,
970+ :: Val{degree} ,
971+ operators:: GenericOperatorEnum{OPS} ,
968972 :: Val{throw_errors} ,
969- ) where {F,T1,T2,N,throw_errors}
970- left, complete = _eval_tree_array_generic (tree. l, cX, operators, Val (throw_errors))
971- ! throw_errors && ! complete && return nothing , false
972- right, complete = _eval_tree_array_generic (tree. r, cX, operators, Val (throw_errors))
973- ! throw_errors && ! complete && return nothing , false
974- ! throw_errors &&
975- ! hasmethod (
976- op,
977- N == 1 ? Tuple{typeof (left),typeof (right)} : Tuple{eltype (left),eltype (right)},
978- ) &&
979- return nothing , false
980- if N == 1
981- return op (left, right), true
982- else
983- return op .(left, right), true
973+ ) where {T1,T2,N,degree,throw_errors,OPS}
974+ nops = length (OPS. types[degree]. types)
975+ quote
976+ cs = get_children (tree, Val ($ degree))
977+ Base. Cartesian. @nexprs (
978+ $ degree,
979+ i -> begin
980+ cumulator_i =
981+ let (x, complete) = _eval_tree_array_generic (
982+ cs[i], cX, operators, Val (throw_errors)
983+ )
984+ ! throw_errors && ! complete && return nothing , false
985+ x
986+ end
987+ end
988+ )
989+ cumulators = Base. Cartesian. @ntuple ($ degree, i -> cumulator_i)
990+ Base. Cartesian. @nif (
991+ $ nops,
992+ i -> i == op_idx,
993+ i -> degn_eval_generic (
994+ cumulators, operators[$ degree][i], Val (N), Val (throw_errors)
995+ )
996+ )
984997 end
985998end
986999
0 commit comments