Skip to content

Commit 067734a

Browse files
committed
feat: make generic eval allow n-arity nodes
1 parent 97abbd0 commit 067734a

File tree

1 file changed

+77
-64
lines changed

1 file changed

+77
-64
lines changed

src/Evaluate.jl

Lines changed: 77 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,6 @@ end
795795
end
796796
end
797797

798-
799798
function deg0_diff_eval(
800799
tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, operators::OperatorEnum
801800
)::ResultOk where {T<:Number,T1}
@@ -816,22 +815,24 @@ end
816815
cX::AbstractMatrix{T},
817816
op_idx::Integer,
818817
::Val{degree},
819-
operators::OperatorEnum{OPS}
818+
operators::OperatorEnum{OPS},
820819
) where {T<:Number,T1,D,degree,OPS}
821820
nops = length(OPS.types[degree].types)
822821
quote
823822
cs = get_children(tree, Val($degree))
824-
Base.Cartesian.@nexprs($degree, i -> begin
825-
cumulator_i = let result = _differentiable_eval_tree_array(cs[i], cX, operators)
826-
!result.ok && return result
827-
result.x
823+
Base.Cartesian.@nexprs(
824+
$degree,
825+
i -> begin
826+
cumulator_i =
827+
let result = _differentiable_eval_tree_array(cs[i], cX, operators)
828+
!result.ok && return result
829+
result.x
830+
end
828831
end
829-
end)
832+
)
830833
cumulators = Base.Cartesian.@ntuple($degree, i -> cumulator_i)
831834
Base.Cartesian.@nif(
832-
$nops,
833-
i -> i == op_idx,
834-
i -> degn_diff_eval(cumulators, operators[$degree][i])
835+
$nops, i -> i == op_idx, i -> degn_diff_eval(cumulators, operators[$degree][i])
835836
)
836837
end
837838
end
@@ -910,77 +911,89 @@ function eval(current_node)
910911
end
911912
end
912913

913-
@unstable function _eval_tree_array_generic(
914-
tree::AbstractExpressionNode{T1},
914+
@generated function _eval_tree_array_generic(
915+
tree::AbstractExpressionNode{T1,D},
915916
cX::AbstractArray{T2,N},
916917
operators::GenericOperatorEnum,
917918
::Val{throw_errors},
918-
) where {T1,T2,N,throw_errors}
919-
if tree.degree == 0
920-
if tree.constant
921-
if N == 1
922-
return (tree.val::T1), true
923-
else
924-
return fill(tree.val::T1, size(cX)[2:N]), true
925-
end
919+
) where {T1,D,T2,N,throw_errors}
920+
quote
921+
tree.degree == 0 && return deg0_eval_generic(tree, cX)
922+
op_idx = tree.op
923+
deg = tree.degree
924+
Base.Cartesian.@nif(
925+
$D,
926+
i -> i == deg,
927+
i -> dispatch_degn_eval_generic(
928+
tree, cX, op_idx, Val(i), operators, Val(throw_errors)
929+
)
930+
)
931+
end
932+
end
933+
934+
@unstable function deg0_eval_generic(
935+
tree::AbstractExpressionNode{T1}, cX::AbstractArray{T2,N}
936+
) where {T1,T2,N}
937+
if tree.constant
938+
if N == 1
939+
return (tree.val::T1), true
926940
else
927-
if N == 1
928-
return (cX[tree.feature]), true
929-
else
930-
return copy(selectdim(cX, 1, tree.feature)), true
931-
end
941+
return fill(tree.val::T1, size(cX)[2:N]), true
932942
end
933-
elseif tree.degree == 1
934-
return deg1_eval_generic(
935-
tree, cX, operators.unaops[tree.op], operators, Val(throw_errors)
936-
)
937943
else
938-
return deg2_eval_generic(
939-
tree, cX, operators.binops[tree.op], operators, Val(throw_errors)
940-
)
944+
if N == 1
945+
return (cX[tree.feature]), true
946+
else
947+
return copy(selectdim(cX, 1, tree.feature)), true
948+
end
941949
end
942950
end
943951

944-
@unstable function deg1_eval_generic(
945-
tree::AbstractExpressionNode{T1},
946-
cX::AbstractArray{T2,N},
947-
op::F,
948-
operators::GenericOperatorEnum,
949-
::Val{throw_errors},
950-
) where {F,T1,T2,N,throw_errors}
951-
left, complete = _eval_tree_array_generic(tree.l, cX, operators, Val(throw_errors))
952-
!throw_errors && !complete && return nothing, false
953-
!throw_errors &&
954-
!hasmethod(op, N == 1 ? Tuple{typeof(left)} : Tuple{eltype(left)}) &&
955-
return nothing, false
952+
@unstable function degn_eval_generic(
953+
cumulators::C, op::F, ::Val{N}, ::Val{throw_errors}
954+
) where {C<:Tuple,F,N,throw_errors}
955+
if !throw_errors
956+
input_type = N == 1 ? C : Tuple{map(eltype, cumulators)...}
957+
!hasmethod(op, input_type) && return nothing, false
958+
end
956959
if N == 1
957-
return op(left), true
960+
return op(cumulators...), true
958961
else
959-
return op.(left), true
962+
return op.(cumulators...), true
960963
end
961964
end
962965

963-
@unstable function deg2_eval_generic(
966+
@generated function dispatch_degn_eval_generic(
964967
tree::AbstractExpressionNode{T1},
965968
cX::AbstractArray{T2,N},
966-
op::F,
967-
operators::GenericOperatorEnum,
969+
op_idx::Integer,
970+
::Val{degree},
971+
operators::GenericOperatorEnum{OPS},
968972
::Val{throw_errors},
969-
) where {F,T1,T2,N,throw_errors}
970-
left, complete = _eval_tree_array_generic(tree.l, cX, operators, Val(throw_errors))
971-
!throw_errors && !complete && return nothing, false
972-
right, complete = _eval_tree_array_generic(tree.r, cX, operators, Val(throw_errors))
973-
!throw_errors && !complete && return nothing, false
974-
!throw_errors &&
975-
!hasmethod(
976-
op,
977-
N == 1 ? Tuple{typeof(left),typeof(right)} : Tuple{eltype(left),eltype(right)},
978-
) &&
979-
return nothing, false
980-
if N == 1
981-
return op(left, right), true
982-
else
983-
return op.(left, right), true
973+
) where {T1,T2,N,degree,throw_errors,OPS}
974+
nops = length(OPS.types[degree].types)
975+
quote
976+
cs = get_children(tree, Val($degree))
977+
Base.Cartesian.@nexprs(
978+
$degree,
979+
i -> begin
980+
cumulator_i =
981+
let (x, complete) = _eval_tree_array_generic(
982+
cs[i], cX, operators, Val(throw_errors)
983+
)
984+
!throw_errors && !complete && return nothing, false
985+
x
986+
end
987+
end
988+
)
989+
cumulators = Base.Cartesian.@ntuple($degree, i -> cumulator_i)
990+
Base.Cartesian.@nif(
991+
$nops,
992+
i -> i == op_idx,
993+
i -> degn_eval_generic(
994+
cumulators, operators[$degree][i], Val(N), Val(throw_errors)
995+
)
996+
)
984997
end
985998
end
986999

0 commit comments

Comments
 (0)