@@ -143,18 +143,21 @@ end
143
143
function rrule (:: typeof (permutedims), x:: AbstractVector )
144
144
project = ProjectTo (x)
145
145
permutedims_pullback_1 (dy) = (NoTangent (), project (permutedims (unthunk (dy))))
146
+ permutedims_pullback_1 (:: ZeroTangent ) = (NoTangent (), ZeroTangent ())
146
147
return permutedims (x), permutedims_pullback_1
147
148
end
148
149
149
150
function rrule (:: typeof (permutedims), x:: AbstractArray , perm)
150
151
pr = ProjectTo (x) # projection restores e.g. transpose([1,2,3])
151
152
permutedims_back_2 (dy) = (NoTangent (), pr (permutedims (unthunk (dy), invperm (perm))), NoTangent ())
153
+ permutedims_back_2 (:: ZeroTangent ) = (NoTangent (), ZeroTangent (), NoTangent ())
152
154
return permutedims (x, perm), permutedims_back_2
153
155
end
154
156
155
157
function rrule (:: Type{<:PermutedDimsArray} , x:: AbstractArray , perm)
156
158
pr = ProjectTo (x)
157
159
permutedims_back_3 (dy) = (NoTangent (), pr (permutedims (unthunk (dy), invperm (perm))), NoTangent ())
160
+ permutedims_back_3 (:: ZeroTangent ) = (NoTangent (), ZeroTangent (), NoTangent ())
158
161
return PermutedDimsArray (x, perm), permutedims_back_3
159
162
end
160
163
0 commit comments