@@ -339,19 +339,19 @@ function LazyStack.stack(s::Symbol, args...)
339
339
ensure_named (data, name_last)
340
340
end
341
341
342
- #= ==== Zygote =====#
342
+ #= ==== Gradients =====#
343
343
344
- using ZygoteRules : @adjoint
344
+ using ChainRulesCore : ChainRulesCore, rrule, NoTangent
345
345
346
- @adjoint function stack ( vec:: AbstractArray{<:AbstractArray{<:Any,IN}} ) where {IN}
347
- stack (vec), Δ -> ([view (Δ, ntuple (_-> (:),IN)... , Tuple (I)... ) for I in eachindex (vec)],)
346
+ function ChainRulesCore . rrule ( :: typeof (stack), vec:: AbstractArray{<:AbstractArray{<:Any,IN}} ) where {IN}
347
+ stack (vec), Δ -> (NoTangent (), [view (Δ, ntuple (_-> (:),IN)... , Tuple (I)... ) for I in eachindex (vec)],)
348
348
end
349
349
350
- @adjoint function stack ( tup:: Tuple{Vararg{<:AbstractArray{<:Any,IN}}} ) where {IN}
351
- stack (tup), Δ -> (ntuple (i -> view (Δ, ntuple (_-> (:),IN)... , i), length (tup)),)
350
+ function ChainRulesCore . rrule ( :: typeof (stack), tup:: Tuple{Vararg{<:AbstractArray{<:Any,IN}}} ) where {IN}
351
+ stack (tup), Δ -> (NoTangent (), ntuple (i -> view (Δ, ntuple (_-> (:),IN)... , i), length (tup)),)
352
352
end
353
353
354
- @adjoint function stack ( gen:: Base.Generator )
354
+ function ChainRulesCore . rrule ( :: typeof (stack), gen:: Base.Generator )
355
355
stack (gen), Δ -> error (" not yet!" )
356
356
end
357
357
0 commit comments