Skip to content

Commit 1ead5e2

Browse files
mcabbottoxinabox
authored andcommitted
fixes
1 parent 134c9be commit 1ead5e2

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

src/projection.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,24 +229,22 @@ end
229229
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
230230
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
231231
# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
232-
(project::ProjectTo{Adjoint})(dx::Adjoint) = adjoint(project.parent(parent(dx)))
233-
(project::ProjectTo{Adjoint})(dx::Transpose) = adjoint(adjoint.(project.parent(parent(dx)))) # might copy twice?
232+
(project::ProjectTo{Adjoint})(dx::LinearAlgebra.AdjOrTransAbsVec) = adjoint(project.parent(adjoint(dx)))
234233
function (project::ProjectTo{Adjoint})(dx::AbstractArray)
235234
size(dx,1) == 1 && size(dx,2) == length(project.parent.axes[1]) || throw(_projection_mismatch((1:1, project.parent.axes...), size(dx)))
236-
dy = project.parent(adjoint(dx))
237-
return adjoint(dy)
235+
dy = eltype(dx) <: Real ? vec(dx) : adjoint(dx)
236+
return adjoint(project.parent(dy))
238237
end
239238

240239
function ProjectTo(x::LinearAlgebra.TransposeAbsVec)
241240
sub = ProjectTo(parent(x))
242241
ProjectTo{Transpose}(; parent=sub)
243242
end
244-
(project::ProjectTo{Transpose})(dx::Transpose) = transpose(project.parent(parent(dx)))
245-
(project::ProjectTo{Transpose})(dx::Adjoint) = transpose(adjoint.(project.parent(parent(dx))))
243+
(project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec) = transpose(project.parent(transpose(dx)))
246244
function (project::ProjectTo{Transpose})(dx::AbstractArray)
247245
size(dx,1) == 1 && size(dx,2) == length(project.parent.axes[1]) || throw(_projection_mismatch((1:1, project.parent.axes...), size(dx)))
248-
dy = project.parent(transpose(dx))
249-
return transpose(dy)
246+
dy = eltype(dx) <: Number ? vec(dx) : transpose(dx)
247+
return transpose(project.parent(dy))
250248
end
251249

252250
# Diagonal

test/projection.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ using OffsetArrays, BenchmarkTools
9898
adjT = typeof(adj([1,2,3.0]))
9999
@test padj(transpose(1:3)) isa adjT
100100
@test padj([4 5 6+7im]) isa adjT
101-
@test_broken padj([4.0 5.0 6.0]) isa adjT
101+
@test padj([4.0 5.0 6.0]) isa adjT
102102

103103
@test_throws DimensionMismatch padj([1,2,3])
104104
@test_throws DimensionMismatch padj([1 2 3]')
@@ -258,7 +258,7 @@ using OffsetArrays, BenchmarkTools
258258
@test repr(ProjectTo(1.1)) == "ProjectTo{Float64}()"
259259
@test occursin("ProjectTo{AbstractArray}(element", repr(ProjectTo([1,2,3])))
260260
str = repr(ProjectTo([1,2,3]'))
261-
@test_broken eval(Meta.parse(str))(ones(1,3)) isa Adjoint{Float64, Vector{Float64}}
261+
@test eval(Meta.parse(str))(ones(1,3)) isa Adjoint{Float64, Vector{Float64}}
262262
end
263263

264264
VERSION > v"1.1" && @testset "allocation tests" begin
@@ -272,7 +272,7 @@ using OffsetArrays, BenchmarkTools
272272

273273
padj = ProjectTo(adjoint(rand(10^3)))
274274
@test 0 == @ballocated $padj(dx) setup=(dx=adjoint(rand(10^3)))
275-
@test_broken 0 == @ballocated $padj(dx) setup=(dx=transpose(rand(10^3)))
275+
@test 0 == @ballocated $padj(dx) setup=(dx=transpose(rand(10^3)))
276276

277277
@test 0 == @ballocated ProjectTo(x')(dx') setup=(x=rand(10^3); dx=rand(10^3))
278278

0 commit comments

Comments
 (0)