|
1 | | -function ChainRulesCore.rrule(::typeof(getindex), VA::AbstractVectorOfArray, |
2 | | - i::Union{Int, AbstractArray{Int}, CartesianIndex, Colon, |
3 | | - BitArray, AbstractArray{Bool}}) |
4 | | - function AbstractVectorOfArray_getindex_adjoint(Δ) |
5 | | - Δ′ = [(i == j ? Δ : zero(x)) for (x, j) in zip(VA.u, 1:length(VA))] |
6 | | - (NoTangent(), VectorOfArray(Δ′), NoTangent()) |
7 | | - end |
8 | | - VA[i], AbstractVectorOfArray_getindex_adjoint |
9 | | -end |
10 | | - |
11 | | -function ChainRulesCore.rrule(::typeof(getindex), VA::AbstractVectorOfArray, |
12 | | - indices::Union{Int, AbstractArray{Int}, CartesianIndex, Colon, |
13 | | - BitArray, AbstractArray{Bool}}...) |
14 | | - function AbstractVectorOfArray_getindex_adjoint(Δ) |
15 | | - Δ′ = zero(VA) |
16 | | - Δ′[indices...] = Δ |
17 | | - (NoTangent(), VectorOfArray(Δ′), map(_ -> NoTangent(), indices)...) |
18 | | - end |
19 | | - VA[indices...], AbstractVectorOfArray_getindex_adjoint |
20 | | -end |
21 | | - |
22 | | -function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, |
23 | | - ::Type{Val{copy_x}} = Val{false}) where {S <: Tuple, copy_x} |
24 | | - function ArrayPartition_adjoint(_y) |
25 | | - y = Array(_y) |
26 | | - starts = vcat(0, cumsum(reduce(vcat, length.(x)))) |
27 | | - NoTangent(), |
28 | | - ntuple(i -> reshape(y[(starts[i] + 1):starts[i + 1]], size(x[i])), length(x)), |
29 | | - NoTangent() |
30 | | - end |
31 | | - |
32 | | - ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint |
33 | | -end |
34 | | - |
35 | | -function ChainRulesCore.rrule(::Type{<:VectorOfArray}, u) |
36 | | - VectorOfArray(u), |
37 | | - y -> (NoTangent(), |
38 | | - [y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]]) |
39 | | -end |
40 | | - |
41 | | -function ChainRulesCore.rrule(::Type{<:DiffEqArray}, u, t) |
42 | | - DiffEqArray(u, t), |
43 | | - y -> (NoTangent(), |
44 | | - [y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]], |
45 | | - NoTangent()) |
46 | | -end |
47 | | - |
48 | | -function ChainRulesCore.rrule(::typeof(getproperty), A::ArrayPartition, s::Symbol) |
49 | | - if s !== :x |
50 | | - error("$s is not a field of ArrayPartition") |
51 | | - end |
52 | | - function literal_ArrayPartition_x_adjoint(d) |
53 | | - (NoTangent(), |
54 | | - ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...)) |
55 | | - end |
56 | | - A.x, literal_ArrayPartition_x_adjoint |
57 | | -end |
58 | | - |
59 | 1 | # Define a new species of projection operator for this type: |
60 | 2 | ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}() |
61 | | - |
62 | | -# Gradient from iteration will be e.g. Vector{Vector}, this makes it another AbstractMatrix |
63 | | -#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractVector{<:AbstractArray}) = VectorOfArray(dx) |
64 | | -# Gradient from broadcasting will be another AbstractArray |
65 | | -#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx |
66 | | - |
67 | | -# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint` |
68 | | -# definition first, and finds its own before finding those. |
0 commit comments