Skip to content

Commit 087fa42

Browse files
fix removal
1 parent 2d6b9cd commit 087fa42

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ end
1313
# Define a new species of projection operator for this type:
1414
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
1515

16+
function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray},
17+
xs::AbstractVectorOfArray)
18+
T(xs), ȳ -> (NoTangent(), ȳ)
19+
end
20+
1621
@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
1722
function AbstractVectorOfArray_getindex_adjoint(Δ)
1823
Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)), size(x)))

src/RecursiveArrayTools.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ end
2727

2828
import GPUArraysCore
2929
Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA)
30-
function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray},
31-
xs::AbstractVectorOfArray)
32-
T(xs), ȳ -> (NoTangent(), ȳ)
33-
end
3430

3531
import Requires
3632
@static if !isdefined(Base, :get_extension)

0 commit comments

Comments
 (0)