Skip to content

Commit a59fe8a

Browse files
Merge pull request #119 from cqql/array-partition-adjoint
Define the custom adjoint on a more general ArrayPartition constructor
2 parents e9a5eba + 8c60597 commit a59fe8a

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/zygote.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i, j...)
1515
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
1616
end
1717

18-
ZygoteRules.@adjoint function ArrayPartition(x...)
18+
ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
1919
function ArrayPartition_adjoint(_y)
2020
y = Array(_y)
2121
starts = vcat(0,cumsum(reduce(vcat,length.(x))))
22-
ntuple(i -> reshape(y[starts[i]+1:starts[i+1]],size(x[i])),length(x))
22+
ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), nothing
2323
end
24-
ArrayPartition(x...),ArrayPartition_adjoint
24+
25+
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
2526
end
2627

2728
ZygoteRules.@adjoint function VectorOfArray(u)

0 commit comments

Comments
 (0)