Skip to content

Commit e737e07

Browse files
Merge pull request #97 from SciML/ap_adjoints
add missing ArrayPartition adjoint
2 parents 84f3051 + 7b85e02 commit e737e07

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

src/array_partition.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ Base.zero(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(zero.(A.x))
5151
# ignore dims since array partitions are vectors
5252
Base.zero(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zero(A)
5353

54+
## Array
55+
56+
Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:ArrayPartition}} = reduce(hcat,Array.(VA.u))
57+
5458
## ones
5559

5660
# special to work with units

src/zygote.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,12 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i, j...)
1414
end
1515
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
1616
end
17+
18+
ZygoteRules.@adjoint function ArrayPartition(x...)
19+
function ArrayPartition_adjoint(_y)
20+
y = Array(_y)
21+
starts = vcat(0,cumsum(reduce(vcat,length.(x))))
22+
ntuple(i -> reshape(y[starts[i]+1:starts[i+1]],size(x[i])),length(x))
23+
end
24+
ArrayPartition(x...),ArrayPartition_adjoint
25+
end

0 commit comments

Comments
 (0)