Skip to content

Commit cb02943

Browse files
mcabbott authored and oxinabox committed
row-array weirdness, evil case
1 parent 32cab7e commit cb02943

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

src/projection.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,17 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
197197
return dz
198198
end
199199

200+
# Row vectors aren't acceptable as gradients for 1-row matrices:
201+
function (project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
    # Re-express the lazy row vector as an ordinary 1×N reshape, then let the
    # generic array method of this projector handle it.
    row = reshape(vec(dx), 1, :)
    return project(row)
end
202+
200203
# Zero-dimensional arrays -- these have a habit of going missing,
201204
# although really Ref() is probably a better structure.
202205
# ... so we restore from numbers
function (project::ProjectTo{AbstractArray})(dx::Number)
    # Only a projector recorded with empty `axes` (a zero-dimensional array)
    # may take a bare number as its gradient.
    if !(project.axes isa Tuple{})
        throw(DimensionMismatch("array with ndims(x) == $(length(project.axes)) > 0 cannot have as gradient dx::Number"))
    end
    # Project the scalar with the element projector and wrap it with `fill`.
    return fill(project.element(dx))
end
206209

207-
# Ref -- works like a zero-array, allowss restoration from a number:
210+
# Ref -- works like a zero-array, also allows restoration from a number:
208211
function ProjectTo(x::Ref)
    # Build a projector for the referenced value and store it in field `x`.
    inner = _always_projector(x[])
    return ProjectTo{Ref}(; x = inner)
end
209212
function (project::ProjectTo{Ref})(dx::Ref)
    # Unwrap, project the contents with the stored projector, and re-wrap.
    inner = project.x(dx[])
    return Ref(inner)
end
210213
function (project::ProjectTo{Ref})(dx::Number)
    # A bare number is projected with the stored projector and wrapped in a Ref.
    inner = project.x(dx)
    return Ref(inner)
end
@@ -219,27 +222,33 @@ end
219222
#####
220223

221224
# Row vectors
222-
function ProjectTo(x::LinearAlgebra.AdjointAbsVec{T}) where {T<:Number}
225+
function ProjectTo(x::LinearAlgebra.AdjointAbsVec)
    # The projector for the wrapped vector is kept in the `parent` field.
    return ProjectTo{Adjoint}(; parent = ProjectTo(parent(x)))
end
226229
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
227230
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
228231
# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
229232
function (project::ProjectTo{Adjoint})(dx::Adjoint)
    # Project the underlying vector, then re-apply the lazy adjoint wrapper.
    inner = project.parent(parent(dx))
    return adjoint(inner)
end
230-
(project::ProjectTo{Adjoint})(dx::Transpose) = adjoint(conj(project.parent(parent(dx)))) # might copy twice?
233+
function (project::ProjectTo{Adjoint})(dx::Transpose)
    # Project the vector underlying the transpose, take the adjoint of each
    # element, then wrap the result in a lazy adjoint.
    inner = project.parent(parent(dx))
    # might copy twice?
    return adjoint(adjoint.(inner))
end
231234
# Generic arrays: accept anything whose first dimension has size 1 and whose second
# dimension matches the length of the parent vector, and rebuild an adjoint row vector.
# Throws the error built by `_projection_mismatch` when the sizes do not agree.
function (project::ProjectTo{Adjoint})(dx::AbstractArray)
    size(dx,1) == 1 && size(dx,2) == length(project.parent.axes[1]) || throw(_projection_mismatch((1:1, project.parent.axes...), size(dx)))
    dy = project.parent(vec(dx))
    if eltype(dy) <: Real
        # The adjoint of a real number is the number itself, so the lazy wrapper
        # alone reproduces the values of dx without another copy.
        return adjoint(dy)
    else
        # The outer lazy `adjoint` also applies `adjoint` to each element on access,
        # so apply it element-wise first so that the two cancel.
        # adjoint.(dy) copies, if project.parent changed the type it copied too, ideally could fuse those
        return adjoint(adjoint.(dy))
    end
end
236245

237-
function ProjectTo(x::LinearAlgebra.TransposeAbsVec{T}) where {T<:Number}
246+
function ProjectTo(x::LinearAlgebra.TransposeAbsVec)
    # The projector for the wrapped vector is kept in the `parent` field.
    return ProjectTo{Transpose}(; parent = ProjectTo(parent(x)))
end
241250
function (project::ProjectTo{Transpose})(dx::Transpose)
    # Project the underlying vector, then re-apply the lazy transpose wrapper.
    inner = project.parent(parent(dx))
    return transpose(inner)
end
242-
(project::ProjectTo{Transpose})(dx::Adjoint) = transpose(conj(project.parent(parent(dx))))
251+
function (project::ProjectTo{Transpose})(dx::Adjoint)
    # Project the vector underlying the adjoint, take the adjoint of each
    # element, then wrap the result in a lazy transpose.
    inner = project.parent(parent(dx))
    return transpose(adjoint.(inner))
end
243252
function (project::ProjectTo{Transpose})(dx::AbstractArray)
244253
size(dx,1) == 1 && size(dx,2) == length(project.parent.axes[1]) || throw(_projection_mismatch((1:1, project.parent.axes...), size(dx)))
245254
dy = project.parent(vec(dx))

test/projection.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using OffsetArrays, BenchmarkTools
1818
@test ProjectTo(big(1.0))(2) === 2
1919
end
2020

21-
@testset "Base: arrays" begin
21+
@testset "Base: arrays of numbers" begin
2222
pvec3 = ProjectTo([1,2,3])
2323
@test pvec3(1.0:3.0) === 1.0:3.0
2424
@test pvec3(1:3) == 1.0:3.0 # would prefer ===, map(Float64, dx) would do that, not important
@@ -35,17 +35,30 @@ using OffsetArrays, BenchmarkTools
3535
@test pmat([1 2; 3 4]') isa Matrix{ComplexF64} # broadcast type change
3636

3737
pmat2 = ProjectTo(rand(2,2)')
38-
@test pmat2([1 2; 3 4.0 + 5im]) isa Matrix # adjoint matrices are not preserved
38+
@test pmat2([1 2; 3 4.0 + 5im]) isa Matrix # adjoint matrices are not re-created
3939

40-
# arrays of arrays
40+
prow = ProjectTo([1im 2 3im])
41+
@test prow(transpose([1, 2, 3+4.0im])) == [1 2 3+4im]
42+
@test prow(transpose([1, 2, 3+4.0im])) isa Matrix # row vectors may not pass through
43+
@test prow(adjoint([1, 2, 3+5im])) == [1 2 3-5im]
44+
@test prow(adjoint([1, 2, 3])) isa Matrix
45+
end
46+
47+
@testset "Base: arrays of arrays, etc" begin
4148
pvecvec = ProjectTo([[1,2], [3,4,5]])
4249
@test pvecvec([1:2, 3:5])[1] == 1:2
4350
@test pvecvec([[1,2+3im], [4+5im,6,7]])[2] == [4,6,7]
4451
@test pvecvec(hcat([1:2, hcat(3:5)]))[2] isa Vector # reshape inner & outer
4552

53+
pvecvec2 = ProjectTo(reshape(Any[[1 2], [3 4 5]],1,2)) # a row of rows
54+
y1 = pvecvec2([[1,2], [3,4,5]]')
55+
@test y1[1] == [1 2]
56+
@test !(y1 isa Adjoint) && !(y1[1] isa Adjoint)
57+
4658
# arrays of unknown things
4759
@test ProjectTo([:x, :y])(1:2) === 1:2 # no element handling,
4860
@test ProjectTo([:x, :y])(reshape(1:2,2,1,1)) == 1:2 # but still reshapes container
61+
4962
@test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number.
5063
@test Tuple(ProjectTo(Any[1, 2+3im])(1:2)) === (1.0, 2.0 + 0.0im)
5164
@test ProjectTo(Any[true, false]) isa ProjectTo{NoTangent}
@@ -94,6 +107,15 @@ using OffsetArrays, BenchmarkTools
94107
@test padj_complex([4 5 6+7im]) == [4 5 6+7im]
95108
@test padj_complex(transpose([4, 5, 6+7im])) == [4 5 6+7im]
96109
@test padj_complex(adjoint([4, 5, 6+7im])) == [4 5 6-7im]
110+
111+
# evil test case
112+
xs = adjoint(Any[Any[1,2,3], Any[4+im,5-im,6+im,7-im]])
113+
pvecvec3 = ProjectTo(xs)
114+
@test pvecvec3(xs)[1] == [1 2 3]
115+
@test pvecvec3(xs)[2] isa Adjoint{ComplexF64, <:Vector}
116+
@test_broken pvecvec3(collect(xs))[1] == [1 2 3]
117+
ys = permutedims([[1 2 3+im], [4 5 6 7]])
118+
@test_broken pvecvec3(ys)[1] == [1 2 3]
97119
end
98120

99121
@testset "LinearAlgebra: structured matrices" begin
@@ -242,7 +264,7 @@ using OffsetArrays, BenchmarkTools
242264

243265
padj = ProjectTo(adjoint(rand(10^3)))
244266
@test 0 == @ballocated $padj(dx) setup=(dx=adjoint(rand(10^3)))
245-
@test 0 == @ballocated $padj(dx) setup=(dx=transpose(rand(10^3)))
267+
@test_broken 0 == @ballocated $padj(dx) setup=(dx=transpose(rand(10^3)))
246268

247269
@test 0 == @ballocated ProjectTo(x')(dx') setup=(x=rand(10^3); dx=rand(10^3))
248270

0 commit comments

Comments
 (0)