@@ -7,6 +7,7 @@ using ChainRulesCore:
77 ZeroTangent,
88 Tangent,
99 @thunk ,
10+ unthunk,
1011 canonicalize
1112using .. OperatorEnumModule: OperatorEnum
1213using .. NodeModule: AbstractExpressionNode, with_type_parameters, tree_mapreduce
@@ -52,7 +53,8 @@ struct EvalPullback{N,A,O} <: Function
5253end
5354
5455# TODO : Preferable to use the primal in the pullback somehow
55- function (e:: EvalPullback )((dY, _))
56+ function (e:: EvalPullback )((thunked_dY, _))
57+ dY = unthunk (thunked_dY)
5658 _, dX_constants_dY, complete = eval_grad_tree_array (
5759 e. tree, e. X, e. operators; variable= Val (:both )
5860 )
@@ -66,10 +68,10 @@ function (e::EvalPullback)((dY, _))
6668 dconstants_dY = @view dX_constants_dY[(nfeatures + 1 ): end , :]
6769
6870 dtree = NodeTangent (
69- e. tree, sum (j -> dconstants_dY[:, j] * dY[j], eachindex (axes (dconstants_dY, 2 )))
71+ e. tree, sum (j -> dconstants_dY[:, j] * dY[j], eachindex (dY, axes (dconstants_dY, 2 )))
7072 )
7173
72- dX = dX_dY .* reshape (dY, 1 , size (dconstants_dY, 2 ))
74+ dX = dX_dY .* reshape (dY, 1 , length (dY ))
7375
7476 return (NoTangent (), dtree, dX, NoTangent ())
7577end
0 commit comments