Skip to content

Commit 024704e

Browse files
handle permutedims rrule with ZeroTangent (#683)
1 parent 1597bcc commit 024704e

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/rulesets/Base/array.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,18 +143,21 @@ end
143143
function rrule(::typeof(permutedims), x::AbstractVector)
144144
project = ProjectTo(x)
145145
permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy))))
146+
permutedims_pullback_1(::ZeroTangent) = (NoTangent(), ZeroTangent())
146147
return permutedims(x), permutedims_pullback_1
147148
end
148149

149150
function rrule(::typeof(permutedims), x::AbstractArray, perm)
150151
pr = ProjectTo(x) # projection restores e.g. transpose([1,2,3])
151152
permutedims_back_2(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent())
153+
permutedims_back_2(::ZeroTangent) = (NoTangent(), ZeroTangent(), NoTangent())
152154
return permutedims(x, perm), permutedims_back_2
153155
end
154156

155157
function rrule(::Type{<:PermutedDimsArray}, x::AbstractArray, perm)
156158
pr = ProjectTo(x)
157159
permutedims_back_3(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent())
160+
permutedims_back_3(::ZeroTangent) = (NoTangent(), ZeroTangent(), NoTangent())
158161
return PermutedDimsArray(x, perm), permutedims_back_3
159162
end
160163

0 commit comments

Comments
 (0)