diff --git a/Project.toml b/Project.toml index 3c24580c..47ac474e 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ Optim = "0.19, 1" PrecompileTools = "1" Reexport = "1" SymbolicUtils = "0.19, ^1.0.5, 2, 3" -Zygote = "0.6" +Zygote = "0.6, 0.7" julia = "1.10" [extras] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index d8273f19..f2dd4e36 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -7,6 +7,7 @@ using ChainRulesCore: ZeroTangent, Tangent, @thunk, + unthunk, canonicalize using ..OperatorEnumModule: OperatorEnum using ..NodeModule: AbstractExpressionNode, with_type_parameters, tree_mapreduce @@ -52,7 +53,8 @@ struct EvalPullback{N,A,O} <: Function end # TODO: Preferable to use the primal in the pullback somehow -function (e::EvalPullback)((dY, _)) +function (e::EvalPullback)((thunked_dY, _)) + dY = unthunk(thunked_dY) _, dX_constants_dY, complete = eval_grad_tree_array( e.tree, e.X, e.operators; variable=Val(:both) )