Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Unitful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ include("logarithm.jl")
include("complex.jl")
include("pkgdefaults.jl")
include("dates.jl")
include("linearalgebra.jl")

end
42 changes: 42 additions & 0 deletions src/linearalgebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using LinearAlgebra

# This function is re-defined during testing, to check we hit the fast path:
linearalgebra_count() = nothing

function LinearAlgebra.mul!(C::StridedVecOrMat{<:AbstractQuantity{T}},
A::StridedMatrix{<:AbstractQuantity{T}},
B::StridedVecOrMat{<:AbstractQuantity{T}},
alpha::Bool, beta::Bool) where {T<:Base.HWNumber}
Copy link
Contributor Author

@mcabbott mcabbott Apr 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The case α, β::Bool is what A * B produces. Allowing other pure numbers ought to be fine. Allowing these to have units I haven't looked into.

I have restricted to A, B, C having the same eltype, as I think this is what BLAS handles. Julia sometimes copies a matrix to promote e.g. Float64 * Float32, but I don't recall at what stage.

5-arg mul! doesn't exist on 1.0, but that's OK, this will just never be called, so you'll get the slow but correct fallback.

# This is exactly how A * B creates C = similar(B, T, ...)
eltype(C) == Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) || error("bad eltypes")
C0 = ustrip(C)
A0 = ustrip(A)
B0 = ustrip(B)
mul!(C0, A0, B0)
linearalgebra_count()
return C
end

function LinearAlgebra.mul!(C::StridedVecOrMat{<:AbstractQuantity{T}},
A::LinearAlgebra.AdjOrTransAbsMat{<:AbstractQuantity{T}, <:StridedMatrix},
B::StridedVecOrMat{<:AbstractQuantity{T}},
alpha::Bool, beta::Bool) where {T<:Base.HWNumber}

eltype(C) == Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) || error("bad eltypes")
C0 = ustrip(C)
A0 = A isa Adjoint ? adjoint(ustrip(parent(A))) : transpose(ustrip(parent(A)))
B0 = ustrip(B)
mul!(C0, A0, B0)
linearalgebra_count()
return C
end

function LinearAlgebra.dot(A::StridedArray{<:AbstractQuantity{T}},
B::StridedArray{<:AbstractQuantity{T}}) where {T<:Base.HWNumber}
A0 = ustrip(A)
B0 = ustrip(B)
C0 = dot(A0, B0)
linearalgebra_count()
C = C0 * oneunit(eltype(A)) * oneunit(eltype(B)) # surely there is an official way
return C
end
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ julia> a[1] = 3u"m"; b
2
```
"""
@inline ustrip(A::Array{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A)
@inline ustrip(A::StridedArray{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A)

@deprecate(ustrip(A::AbstractArray{T}) where {T<:Number}, ustrip.(A))

Expand Down
44 changes: 44 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ using Dates:

const colon = Base.:(:)

ambig = sort(detect_ambiguities(Unitful), by = a -> [string(a[1].name), string(a[2].module)])
if length(ambig) > 0
println(stdout, "detect_ambiguities(Unitful) found $(length(ambig)) issues:")
Copy link
Contributor Author

@mcabbott mcabbott Apr 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was concerned about creating ambiguities, so made it print them before starting.

In the end, overloading mul! not * seems not to create problems. But there are (for me) 30 other ambiguities this detects (and prints out).

Now moved to #439 instead.

for i in 1:length(ambig)
println(stdout, "[",i, "]:")
println(stdout, " ", ambig[i][1])
println(stdout, " ", ambig[i][2])
end
println(stdout)
end
@testset "Construction" begin
@test isa(NoUnits, FreeUnits)
@test typeof(𝐋) === Unitful.Dimensions{(Unitful.Dimension{:Length}(1),)}
Expand Down Expand Up @@ -84,6 +94,40 @@ const colon = Base.:(:)
@test ConstructionBase.constructorof(typeof(1.0m))(2) === 2m
end

@testset "LinearAlgebra functions" begin
CNT = Ref(0)
Unitful.linearalgebra_count() = (CNT[] += 1; nothing)
@testset "> Matrix multiplication: *" begin
M = rand(3,3) .* u"m"
M_ = view(M,:,1:3)
v = rand(3) .* u"V"
v_ = view(v, 1:3)

CNT[] = 0

@test unit(first(M * M)) == u"m*m"
@test M * M == M_ * M == M * M_ == M_ * M_

@test unit(first(M * v)) == u"m*V"
@test M * v == M_ * v == M * v_ == M_ * v_

@test CNT[] == 10

@test unit(first(v' * M)) == u"m*V"
@test v' * M == v_' * M == v_' * M == v_' * M_

@test CNT[] == 15

@test unit(v' * v) == u"V*V"
@test v' * v == v_' * v == v_' * v == v_' * v_

@test CNT[] == 20
end
@testset "> Matrix multiplication: mul!" begin

end
end

@testset "Types" begin
@test Base.complex(Quantity{Float64,NoDims,NoUnits}) ==
Quantity{Complex{Float64},NoDims,NoUnits}
Expand Down