Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/Unitful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import Base: steprange_last, unsigned
end

import Dates
import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal
import LinearAlgebra: istril, istriu, norm
import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal, Adjoint, Transpose, AdjOrTransAbsMat
import LinearAlgebra: istril, istriu, norm, mul!, dot, /, \, inv, pinv
import Random

import ConstructionBase: constructorof
Expand Down Expand Up @@ -69,5 +69,6 @@ include("logarithm.jl")
include("complex.jl")
include("pkgdefaults.jl")
include("dates.jl")
include("linearalgebra.jl")

end
84 changes: 84 additions & 0 deletions src/linearalgebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@

# Multiplication

function mul!(C::StridedVecOrMat{<:AbstractQuantity{T}},
A::StridedMatrix{<:AbstractQuantity{T}},
B::StridedVecOrMat{<:AbstractQuantity{T}},
alpha::Number, beta::Number) where {T<:Base.HWNumber}
_mul!(C, A, B, alpha, beta)
end

function mul!(C::StridedVecOrMat{<:AbstractQuantity{T}},
A::AdjOrTransAbsMat{<:AbstractQuantity{T}, <:StridedMatrix},
B::StridedVecOrMat{<:AbstractQuantity{T}},
alpha::Number, beta::Number) where {T<:Base.HWNumber}
_mul!(C, A, B, alpha, beta)
end

function _mul!(C, A, B, alpha, beta)
if unit(beta) != NoUnits
throw(DimensionError("beta", 1.0))
elseif unit(eltype(C)) != unit(eltype(A)) * unit(eltype(B)) * unit(alpha)
throw(DimensionError("A * B .* α", "C"))
end
C0 = ustrip(C)
A0 = ustrip(A)
B0 = ustrip(B)
mul!(C0, A0, B0)
_linearalgebra_count()
return C
end

function dot(A::StridedArray{<:AbstractQuantity{T}},
B::StridedArray{<:AbstractQuantity{T}}) where {T<:Base.HWNumber}
A0 = ustrip(A)
B0 = ustrip(B)
C0 = dot(A0, B0)
_linearalgebra_count()
C = C0 * unit(eltype(A)) * unit(eltype(B))
return C
end

# Division

function (\)(A::StridedMatrix{<:AbstractQuantity{T}},
B::StridedVecOrMat{<:AbstractQuantity{T}}) where {T<:Base.HWNumber}
A0 = ustrip(A)
B0 = ustrip(B)
C0 = A0 \ B0
_linearalgebra_count()
u = unit(eltype(B)) / unit(eltype(A))
Tu = typeof(one(eltype(C0)) * u)
return reinterpret(Tu, C0)
end

function (/)(A::StridedVecOrMat{<:AbstractQuantity{T}},
B::StridedVecOrMat{<:AbstractQuantity{T}}) where {T<:Base.HWNumber}
A0 = ustrip(A)
B0 = ustrip(B)
C0 = A0 / B0
_linearalgebra_count()
u = unit(eltype(A)) / unit(eltype(B))
Tu = typeof(one(eltype(C0)) * u)
return reinterpret(Tu, C0)
end

function inv(A::StridedMatrix{<:AbstractQuantity{T}}) where {T<:Base.HWNumber}
C0 = inv(ustrip(A))
_linearalgebra_count()
u = inv(unit(eltype(A)))
Tu = typeof(one(eltype(C0)) * u)
return reinterpret(Tu, C0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that these methods will return a ReinterpretArray{Quantity{... not a Matrix. There's no obvious analogue the mul! overloaded for A * B.

end

function pinv(A::StridedMatrix{<:AbstractQuantity{T}}; kw...) where {T<:Base.HWNumber}
C0 = pinv(ustrip(A); kw...)
_linearalgebra_count()
u = inv(unit(eltype(A)))
Tu = typeof(one(eltype(C0)) * u)
return reinterpret(Tu, C0)
end

# This function is re-defined during testing, to check we hit the fast path:
_linearalgebra_count() = nothing

8 changes: 6 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ true
@inline ustrip(x::Missing) = missing

"""
ustrip(x::Array{Q}) where {Q <: Quantity}
ustrip(x::Array{Q}) where {Q <: Quantity{T}}}

Strip units from an `Array` by reinterpreting to type `T`. The resulting
`Array` is a not a copy, but rather a unit-stripped view into array `x`. Because the units
are removed, information may be lost and this should be used with some care.
Expand All @@ -75,7 +76,7 @@ julia> a[1] = 3u"m"; b
2
```
"""
@inline ustrip(A::Array{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A)
@inline ustrip(A::StridedArray{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A)

@deprecate(ustrip(A::AbstractArray{T}) where {T<:Number}, ustrip.(A))

Expand All @@ -91,6 +92,9 @@ ustrip(A::Bidiagonal) = Bidiagonal(ustrip(A.dv), ustrip(A.ev), ifelse(istriu(A),
ustrip(A::Tridiagonal) = Tridiagonal(ustrip(A.dl), ustrip(A.d), ustrip(A.du))
ustrip(A::SymTridiagonal) = SymTridiagonal(ustrip(A.dv), ustrip(A.ev))

ustrip(A::Adjoint) = adjoint(ustrip(parent(A)))
ustrip(A::Transpose) = transpose(ustrip(parent(A)))

"""
unit(x::Quantity{T,D,U}) where {T,D,U}
unit(x::Type{Quantity{T,D,U}}) where {T,D,U}
Expand Down
70 changes: 70 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ using Dates:

const colon = Base.:(:)

ambig = sort(detect_ambiguities(Unitful), by = a -> [string(a[1].name), string(a[2].module)])
if length(ambig) > 0
println(stdout, "detect_ambiguities(Unitful) found $(length(ambig)) issues:")
Copy link
Contributor Author

@mcabbott mcabbott Apr 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was concerned about creating ambiguities, so made it print them before starting.

In the end, overloading mul! not * seems not to create problems. But there are (for me) 30 other ambiguities this detects (and prints out).

Now moved to #439 instead.

for i in 1:length(ambig)
println(stdout, "[",i, "]:")
println(stdout, " ", ambig[i][1])
println(stdout, " ", ambig[i][2])
end
println(stdout)
end
@testset "Construction" begin
@test isa(NoUnits, FreeUnits)
@test typeof(𝐋) === Unitful.Dimensions{(Unitful.Dimension{:Length}(1),)}
Expand Down Expand Up @@ -84,6 +94,66 @@ const colon = Base.:(:)
@test ConstructionBase.constructorof(typeof(1.0m))(2) === 2m
end

@testset "LinearAlgebra functions" begin
CNT = Ref(0)
Unitful._linearalgebra_count() = (CNT[] += 1; nothing)
@testset "> Matrix multiplication: *" begin
M = rand(3,3) .* u"m"
M_ = view(M,:,1:3)
v = rand(3) .* u"V"
v_ = view(v, 1:3)

CNT[] = 0

@test unit(first(M * M)) == u"m*m"
@test M * M == M_ * M == M * M_ == M_ * M_

@test unit(first(M * v)) == u"m*V"
@test M * v == M_ * v == M * v_ == M_ * v_

VERSION >= v"1.3" && @test CNT[] == 10

@test unit(first(v' * M)) == u"m*V"
@test v' * M == v_' * M == v_' * M == v_' * M_

VERSION >= v"1.3" && @test CNT[] == 15

@test unit(v' * v) == u"V*V"
@test v' * v == v_' * v == v_' * v == v_' * v_

VERSION >= v"1.3" && @test CNT[] == 20

# Mixed with & without units
N = rand(3,3)
w = rand(3)

CNT[] = 0

@test unit(first(M * N)) == u"m"
@test unit(first(N * M)) == u"m"

@test unit(first(M * w)) == u"m"
@test unit(first(N * v)) == u"V"

@show CNT[] # not specialised yet

end
@testset "> Matrix multiplication: mul!" begin
A = rand(3,3) .* u"m"
B = rand(3,3) .* u"m"
C = fill(zero(eltype(A*B)), 3, 3)
CNT[] = 0

mul!(C, A, B)
if VERSION >= v"1.3" # the 5-arm mul! exists
mul!(C, A, B, true, true)
mul!(C, A, B, 3, 7) # not specialised yet

@show CNT[]
end
end
end

@testset "Types" begin
@test Base.complex(Quantity{Float64,NoDims,NoUnits}) ==
Quantity{Complex{Float64},NoDims,NoUnits}
Expand Down