@@ -4,13 +4,13 @@ import RecursiveArrayTools
44
55if isdefined (Base, :get_extension )
66 using Zygote
7- using Zygote: ZygoteRules, FullArrays
7+ using Zygote: FullArrays, literal_getproperty, @adjoint
88else
99 using .. Zygote
10- using .. Zygote: ZygoteRules, FullArrays
10+ using .. Zygote: FullArrays, literal_getproperty, @adjoint
1111end
1212
13- ZygoteRules . @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Int )
13+ @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Int )
1414 function AbstractVectorOfArray_getindex_adjoint (Δ)
1515 Δ′ = [(i == j ? Δ : Fill (zero (eltype (x)), size (x)))
1616 for (x, j) in zip (VA. u, 1 : length (VA))]
@@ -19,8 +19,8 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
1919 VA[i], AbstractVectorOfArray_getindex_adjoint
2020end
2121
22- ZygoteRules . @adjoint function getindex (VA:: AbstractVectorOfArray ,
23- i:: Union{BitArray, AbstractArray{Bool}} )
22+ @adjoint function getindex (VA:: AbstractVectorOfArray ,
23+ i:: Union{BitArray, AbstractArray{Bool}} )
2424 function AbstractVectorOfArray_getindex_adjoint (Δ)
2525 Δ′ = [(i[j] ? Δ[j] : Fill (zero (eltype (x)), size (x)))
2626 for (x, j) in zip (VA. u, 1 : length (VA))]
@@ -29,7 +29,7 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray,
2929 VA[i], AbstractVectorOfArray_getindex_adjoint
3030end
3131
32- ZygoteRules . @adjoint function getindex (VA:: AbstractVectorOfArray , i:: AbstractArray{Int} )
32+ @adjoint function getindex (VA:: AbstractVectorOfArray , i:: AbstractArray{Int} )
3333 function AbstractVectorOfArray_getindex_adjoint (Δ)
3434 iter = 0
3535 Δ′ = [(j ∈ i ? Δ[iter += 1 ] : Fill (zero (eltype (x)), size (x)))
@@ -39,8 +39,8 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArr
3939 VA[i], AbstractVectorOfArray_getindex_adjoint
4040end
4141
42- ZygoteRules . @adjoint function getindex (VA:: AbstractVectorOfArray ,
43- i:: Union{Int, AbstractArray{Int}} )
42+ @adjoint function getindex (VA:: AbstractVectorOfArray ,
43+ i:: Union{Int, AbstractArray{Int}} )
4444 function AbstractVectorOfArray_getindex_adjoint (Δ)
4545 Δ′ = [(i[j] ? Δ[j] : Fill (zero (eltype (x)), size (x)))
4646 for (x, j) in zip (VA. u, 1 : length (VA))]
@@ -49,16 +49,16 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray,
4949 VA[i], AbstractVectorOfArray_getindex_adjoint
5050end
5151
52- ZygoteRules . @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Colon )
52+ @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Colon )
5353 function AbstractVectorOfArray_getindex_adjoint (Δ)
5454 (VectorOfArray (Δ), nothing )
5555 end
5656 VA[i], AbstractVectorOfArray_getindex_adjoint
5757end
5858
59- ZygoteRules . @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Int ,
60- j:: Union {Int, AbstractArray{Int}, CartesianIndex,
61- Colon, BitArray, AbstractArray{Bool}}. .. )
59+ @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Int ,
60+ j:: Union {Int, AbstractArray{Int}, CartesianIndex,
61+ Colon, BitArray, AbstractArray{Bool}}. .. )
6262 function AbstractVectorOfArray_getindex_adjoint (Δ)
6363 Δ′ = VectorOfArray ([zero (x) for (x, j) in zip (VA. u, 1 : length (VA))])
6464 Δ′[i, j... ] = Δ
@@ -67,12 +67,12 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int,
6767 VA[i, j... ], AbstractVectorOfArray_getindex_adjoint
6868end
6969
70- ZygoteRules . @adjoint function ArrayPartition (x:: S ,
71- :: Type{Val{copy_x}} = Val{false }) where {
72- S < :
73- Tuple,
74- copy_x
75- }
70+ @adjoint function ArrayPartition (x:: S ,
71+ :: Type{Val{copy_x}} = Val{false }) where {
72+ S < :
73+ Tuple,
74+ copy_x
75+ }
7676 function ArrayPartition_adjoint (_y)
7777 y = Array (_y)
7878 starts = vcat (0 , cumsum (reduce (vcat, length .(x))))
@@ -83,23 +83,23 @@ ZygoteRules.@adjoint function ArrayPartition(x::S,
8383 ArrayPartition (x, Val{copy_x}), ArrayPartition_adjoint
8484end
8585
86- ZygoteRules . @adjoint function VectorOfArray (u)
86+ @adjoint function VectorOfArray (u)
8787 VectorOfArray (u),
8888 y -> (VectorOfArray ([y[ntuple (x -> Colon (), ndims (y) - 1 )... , i]
8989 for i in 1 : size (y)[end ]]),)
9090end
9191
92- ZygoteRules . @adjoint function DiffEqArray (u, t)
92+ @adjoint function DiffEqArray (u, t)
9393 DiffEqArray (u, t),
9494 y -> (DiffEqArray ([y[ntuple (x -> Colon (), ndims (y) - 1 )... , i] for i in 1 : size (y)[end ]],
9595 t), nothing )
9696end
9797
98- ZygoteRules . @adjoint function ZygoteRules . literal_getproperty (A:: ArrayPartition , :: Val{:x} )
98+ @adjoint function literal_getproperty (A:: ArrayPartition , :: Val{:x} )
9999 function literal_ArrayPartition_x_adjoint (d)
100100 (ArrayPartition ((isnothing (d[i]) ? zero (A. x[i]) : d[i] for i in 1 : length (d)). .. ),)
101101 end
102102 A. x, literal_ArrayPartition_x_adjoint
103103end
104104
105- end
105+ end
0 commit comments