|
31 | 31 |
|
32 | 32 | @adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int}) |
33 | 33 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
| 34 | + @show "in hete at vecint" |
34 | 35 | iter = 0 |
35 | 36 | Δ′ = [(j ∈ i ? Δ[iter += 1] : FillArrays.Fill(zero(eltype(x)), size(x))) |
36 | 37 | for (x, j) in zip(VA.u, 1:length(VA))] |
|
77 | 78 | ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint |
78 | 79 | end |
79 | 80 |
|
80 | | -@adjoint function VectorOfArray(u) |
81 | | - VectorOfArray(u), |
82 | | - y -> begin |
83 | | - y isa Ref && (y = VectorOfArray(y[].u)) |
84 | | - (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] |
85 | | - for i in 1:size(y)[end]]),) |
86 | | - end |
87 | | -end |
| 81 | +# @adjoint function VectorOfArray(u) |
| 82 | +# VectorOfArray(u), |
| 83 | +# y -> begin |
| 84 | +# y isa Ref && (y = VectorOfArray(y[].u)) |
| 85 | +# (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] |
| 86 | +# for i in 1:size(y)[end]]),) |
| 87 | +# end |
| 88 | +# end |
88 | 89 |
|
89 | 90 | @adjoint function Base.copy(u::VectorOfArray) |
90 | 91 | copy(u), |
|
145 | 146 |
|
146 | 147 | function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{ |
147 | 148 | AbstractArray, AbstractVectorOfArray}) |
148 | | - arr = reshape(x, p.sz) |
149 | | - return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) |
| 149 | + if eltype(x) <: Number |
| 150 | + arr = reshape(x, p.sz) |
| 151 | + return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) |
| 152 | + elseif eltype(x) <: AbstractArray |
| 153 | + return VectorOfArray(x) |
| 154 | + end |
150 | 155 | end |
151 | 156 |
|
152 | 157 | @adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, |
|
271 | 276 | ȳ -> (nothing, Zygote._project(x, ȳ)) |
272 | 277 |
|
273 | 278 | function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄) |
| 279 | + @show x̄ |
274 | 280 | N = ndims(x̄) |
275 | 281 | if length(x) == length(x̄) |
276 | 282 | Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors |
|
0 commit comments