Skip to content

Commit f45e0ef

Browse files
committed
fix #685
1 parent 2aad117 commit f45e0ef

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

src/projection.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ ProjectTo(::Any) = identity
128128
ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pass makes this one projector,
129129
(::ProjectTo{NoTangent})(dx) = NoTangent() # but this is the projection only for nonzero gradients,
130130
(::ProjectTo{NoTangent})(dx::AbstractZero) = dx # and this one solves an ambiguity.
131+
(::ProjectTo{NoTangent})(::AbstractThunk) = NoTangent() # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/685
132+
(::ProjectTo{NoTangent})(::Thunk) = NoTangent()
131133

132134
# Also, any explicit construction with fields, where all fields project to zero, itself
133135
# projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
@@ -277,7 +279,7 @@ end
277279
# but as `Ref{Any}((x=val,))`. Here we use a Tangent, there is at present no mutable version, but see
278280
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105
279281
function ProjectTo(x::Ref)
280-
sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
282+
sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
281283
return ProjectTo{Tangent{typeof(x)}}(; x=sub)
282284
end
283285
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx))))

test/projection.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ struct NoSuperType end
8080

8181
prow = ProjectTo([1im 2 3im])
8282
@test prow(transpose([1, 2, 3 + 4.0im])) == [1 2 3 + 4im]
83-
@test prow(transpose([1, 2, 3 + 4.0im])) isa Matrix # row vectors may not pass through
83+
@test prow(transpose([1, 2, 3 + 4.0im])) isa Matrix # row vectors may not pass through
8484
@test prow(adjoint([1, 2, 3 + 5im])) == [1 2 3 - 5im]
8585
@test prow(adjoint([1, 2, 3])) isa Matrix
8686

@@ -145,7 +145,7 @@ struct NoSuperType end
145145

146146
@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
147147
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
148-
148+
149149
@test ProjectTo(Ref(1.0))(Ref(NoTangent())) === NoTangent() # collapse all-zero
150150
end
151151

@@ -376,7 +376,7 @@ struct NoSuperType end
376376

377377
pvec3 = ProjectTo([1, 2, 3])
378378
@test axes(pvec3(OffsetArray(rand(3), 0:2))) == (1:3,)
379-
@test pvec3(OffsetArray(rand(3), 0:2)) isa Vector # relies on axes === axes test
379+
@test pvec3(OffsetArray(rand(3), 0:2)) isa Vector # relies on axes === axes test
380380
@test pvec3(OffsetArray(rand(3,1), 0:2, 0:0)) isa Vector
381381
end
382382

@@ -463,4 +463,8 @@ struct NoSuperType end
463463
psymm = ProjectTo(Symmetric(rand(10^3, 10^3)))
464464
@test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64
465465
end
466+
467+
@testset "#685" begin
468+
@test ProjectTo(BitArray([0]))(@thunk[1.0]) == NoTangent()
469+
end
466470
end

0 commit comments

Comments
 (0)