Skip to content

Commit caa8a32

Browse files
committed
fix vector methods
1 parent f34f539 commit caa8a32

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

src/generic.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,18 +1169,17 @@ end
11691169
inv(A::Adjoint) = adjoint(inv(parent(A)))
11701170
inv(A::Transpose) = transpose(inv(parent(A)))
11711171

1172-
pinv(v::AbstractVector{T}, tol::Real = real(zero(T))) where {T<:Real} = _vectorpinv(transpose, v, tol)
1173-
pinv(v::AbstractVector{T}, tol::Real = real(zero(T))) where {T<:Complex} = _vectorpinv(adjoint, v, tol)
1174-
pinv(v::AbstractVector{T}, tol::Real = real(zero(T))) where {T} = _vectorpinv(adjoint, v, tol)
1175-
function _vectorpinv(dualfn::Tf, v::AbstractVector{Tv}, tol) where {Tv,Tf}
1176-
res = dualfn(similar(v, typeof(zero(Tv) / (abs2(one(Tv)) + abs2(one(Tv))))))
1172+
_pinvadjoint(v::AbstractVector{T}) where {T<:Real} = transpose(v)
1173+
_pinvadjoint(v::AbstractVector) = adjoint(v)
1174+
function pinv(v::AbstractVector{T}, tol::Real = real(zero(T))) where {T}
1175+
res = _pinvadjoint(similar(v, typeof(zero(T) / (abs2(one(T)) + abs2(one(T))))))
11771176
den = sum(abs2, v)
11781177
# as tol is the threshold relative to the maximum singular value, for a vector with
11791178
# single singular value σ=√den, σ ≦ tol*σ is equivalent to den=0 ∨ tol≥1
11801179
if iszero(den) || tol >= one(tol)
11811180
fill!(res, zero(eltype(res)))
11821181
else
1183-
res .= dualfn(v) ./ den
1182+
res .= _pinvadjoint(v) ./ den
11841183
end
11851184
return res
11861185
end
@@ -1243,7 +1242,11 @@ function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
12431242
return qr(_maybecopy(A, T), ColumnNorm()) \ B
12441243
end
12451244

1246-
(\)(a::AbstractVector, b::AbstractArray) = pinv(a) * b
1245+
function (\)(a::AbstractVector, b::AbstractArray)
1246+
den = sum(abs2, a)
1247+
goodden = den == 0 ? one(den) : den
1248+
return _pinvadjoint(a) * b / goodden
1249+
end
12471250
"""
12481251
A / B
12491252
@@ -1274,7 +1277,11 @@ function (/)(A::AbstractVecOrMat, B::AbstractVecOrMat)
12741277
end
12751278
# \(A::StridedMatrix,x::Number) = inv(A)*x Should be added at some point when the old elementwise version has been deprecated long enough
12761279
# /(x::Number,A::StridedMatrix) = x*inv(A)
1277-
/(x::Number, v::AbstractVector) = x*pinv(v)
1280+
function (/)(x::Number, v::AbstractVector)
1281+
den = sum(abs2, v)
1282+
goodden = den == 0 ? one(den) : den
1283+
return (x / goodden) * _pinvadjoint(v)
1284+
end
12781285

12791286
cond(x::Number) = iszero(x) ? Inf : 1.0
12801287
cond(x::Number, p) = cond(x)

test/generic.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -957,10 +957,21 @@ end
957957
@testset "issue 930" begin
958958
A = rand(Int, 2, 2)
959959
B = rand(Int, 2, 3)
960-
for M (A, B), T (Float32, BigFloat)
960+
C = rand(Int, 2)
961+
for T (Float32, BigFloat)
961962
v = randn(T, 2)
962-
x = @inferred M \ v
963+
x = @inferred C \ v
963964
@test eltype(x) <: T
965+
x = @inferred zero(C) \ v
966+
@test eltype(x) <: T
967+
x = @inferred T(1) / C
968+
@test eltype(x) <: T
969+
x = @inferred T(1) / zero(C)
970+
@test eltype(x) <: T
971+
for M (A, B)
972+
x = @inferred M \ v
973+
@test eltype(x) <: T
974+
end
964975
end
965976
end
966977

0 commit comments

Comments
 (0)