Skip to content

Commit 2d8dcdf

Browse files
authored
Internal type promotion fixes (#112)
1 parent e5a383f commit 2d8dcdf

11 files changed

+50
-40
lines changed

src/LinearMaps.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,20 @@ Base.ndims(::LinearMap) = 2
4848
Base.size(A::LinearMap, n) = (n==1 || n==2 ? size(A)[n] : error("LinearMap objects have only 2 dimensions"))
4949
Base.length(A::LinearMap) = size(A)[1] * size(A)[2]
5050

51+
# check dimension consistency for multiplication A*B
52+
_iscompatible((A, B)) = size(A, 2) == size(B, 1)
53+
function check_dim_mul(A, B)
54+
_iscompatible((A, B)) ||
55+
throw(DimensionMismatch("second dimension of left factor, $(size(A, 2)), " *
56+
"does not match first dimension of right factor, $(size(B, 1))"))
57+
return nothing
58+
end
5159
# check dimension consistency for multiplication C = A*B
5260
function check_dim_mul(C, A, B)
53-
# @info "checked vector dimensions" # uncomment for testing
5461
mA, nA = size(A) # A always has two dimensions
5562
mB, nB = size(B, 1), size(B, 2)
56-
(mB == nA) ||
57-
throw(DimensionMismatch("left factor has dimensions ($mA,$nA), right factor has dimensions ($mB,$nB)"))
58-
(size(C, 1) != mA || size(C, 2) != nB) &&
59-
throw(DimensionMismatch("result has dimensions $(size(C)), needs ($mA,$nB)"))
63+
(mB == nA && size(C, 1) == mA && size(C, 2) == nB) ||
64+
throw(DimensionMismatch("A has size ($mA,$nA), B has size ($mB,$nB), C has size $(size(C))"))
6065
return nothing
6166
end
6267

@@ -93,10 +98,8 @@ julia> A*x
9398
```
9499
"""
95100
function Base.:(*)(A::LinearMap, x::AbstractVector)
96-
m, n = size(A)
97-
n == length(x) || throw(DimensionMismatch("linear map has dimensions ($m,$n), " *
98-
"vector has length $(length(x))"))
99-
return mul!(similar(x, promote_type(eltype(A), eltype(x)), m), A, x)
101+
check_dim_mul(A, x)
102+
return mul!(similar(x, promote_type(eltype(A), eltype(x)), size(A, 1)), A, x)
100103
end
101104

102105
"""

src/blockmap.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ struct BlockMap{T,As<:Tuple{Vararg{LinearMap}},Rs<:Tuple{Vararg{Int}},Rranges<:T
44
rowranges::Rranges
55
colranges::Cranges
66
function BlockMap{T,R,S}(maps::R, rows::S) where {T, R<:Tuple{Vararg{LinearMap}}, S<:Tuple{Vararg{Int}}}
7-
for A in maps
8-
promote_type(T, eltype(A)) == T || throw(InexactError())
7+
for n in eachindex(maps)
8+
A = maps[n]
9+
@assert promote_type(T, eltype(A)) == T "eltype $(eltype(A)) cannot be promoted to $T in BlockMap constructor"
910
end
1011
rowranges, colranges = rowcolranges(maps, rows)
1112
return new{T,R,S,typeof(rowranges),typeof(colranges)}(maps, rows, rowranges, colranges)
@@ -391,8 +392,9 @@ struct BlockDiagonalMap{T,As<:Tuple{Vararg{LinearMap}},Ranges<:Tuple{Vararg{Unit
391392
rowranges::Ranges
392393
colranges::Ranges
393394
function BlockDiagonalMap{T,As}(maps::As) where {T, As<:Tuple{Vararg{LinearMap}}}
394-
for A in maps
395-
promote_type(T, eltype(A)) == T || throw(InexactError())
395+
for n in eachindex(maps)
396+
A = maps[n]
397+
@assert promote_type(T, eltype(A)) == T "eltype $(eltype(A)) cannot be promoted to $T in BlockDiagonalMap constructor"
396398
end
397399
# row ranges
398400
inds = vcat(1, size.(maps, 1)...)

src/composition.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
# helper function
2-
check_dim_mul(A, B) = size(A, 2) == size(B, 1)
3-
41
struct CompositeMap{T, As<:Tuple{Vararg{LinearMap}}} <: LinearMap{T}
52
maps::As # stored in order of application to vector
63
function CompositeMap{T, As}(maps::As) where {T, As}
74
N = length(maps)
85
for n in 2:N
9-
check_dim_mul(maps[n], maps[n-1]) || throw(DimensionMismatch("CompositeMap"))
6+
check_dim_mul(maps[n], maps[n-1])
107
end
11-
for n in 1:N
12-
promote_type(T, eltype(maps[n])) == T || throw(InexactError())
8+
for n in eachindex(maps)
9+
A = maps[n]
10+
@assert promote_type(T, eltype(A)) == T "eltype $(eltype(A)) cannot be promoted to $T in CompositeMap constructor"
1311
end
1412
new{T, As}(maps)
1513
end

src/kronecker.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
struct KroneckerMap{T, As<:Tuple{Vararg{LinearMap}}} <: LinearMap{T}
22
maps::As
33
function KroneckerMap{T, As}(maps::As) where {T, As}
4-
for A in maps
5-
promote_type(T, eltype(A)) == T || throw(InexactError())
4+
for n in eachindex(maps)
5+
A = maps[n]
6+
@assert promote_type(T, eltype(A)) == T "eltype $(eltype(A)) cannot be promoted to $T in KroneckerMap constructor"
67
end
78
return new{T,As}(maps)
89
end
@@ -151,7 +152,7 @@ end
151152
function _unsafe_mul!(y::AbstractVecOrMat, L::CompositeMap{<:Any,<:Tuple{KroneckerMap,KroneckerMap}}, x::AbstractVector)
152153
require_one_based_indexing(y)
153154
B, A = L.maps
154-
if length(A.maps) == length(B.maps) && all(M -> check_dim_mul(M[1], M[2]), zip(A.maps, B.maps))
155+
if length(A.maps) == length(B.maps) && all(_iscompatible, zip(A.maps, B.maps))
155156
_unsafe_mul!(y, kron(map(*, A.maps, B.maps)...), x)
156157
else
157158
_unsafe_mul!(y, LinearMap(A)*B, x)
@@ -166,7 +167,7 @@ function _unsafe_mul!(y::AbstractVecOrMat, L::CompositeMap{T,<:Tuple{Vararg{Kron
166167
Bs = map(AB -> AB.maps[2], L.maps)
167168
As1, As2 = Base.front(As), Base.tail(As)
168169
Bs1, Bs2 = Base.front(Bs), Base.tail(Bs)
169-
apply = all(A -> check_dim_mul(A...), zip(As1, As2)) && all(A -> check_dim_mul(A...), zip(Bs1, Bs2))
170+
apply = all(_iscompatible, zip(As1, As2)) && all(_iscompatible, zip(Bs1, Bs2))
170171
if apply
171172
_unsafe_mul!(y, kron(prod(As), prod(Bs)), x)
172173
else
@@ -181,9 +182,10 @@ end
181182
struct KroneckerSumMap{T, As<:Tuple{LinearMap,LinearMap}} <: LinearMap{T}
182183
maps::As
183184
function KroneckerSumMap{T, As}(maps::As) where {T, As}
184-
for A in maps
185+
for n in eachindex(maps)
186+
A = maps[n]
185187
size(A, 1) == size(A, 2) || throw(ArgumentError("operators need to be square in Kronecker sums"))
186-
promote_type(T, eltype(A)) == T || throw(InexactError())
188+
@assert promote_type(T, eltype(A)) == T "eltype $(eltype(A)) cannot be promoted to $T in KroneckerSumMap constructor"
187189
end
188190
return new{T,As}(maps)
189191
end

src/linearcombination.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ struct LinearCombination{T, As<:Tuple{Vararg{LinearMap}}} <: LinearMap{T}
33
function LinearCombination{T, As}(maps::As) where {T, As}
44
N = length(maps)
55
sz = size(maps[1])
6-
for n in 1:N
7-
size(maps[n]) == sz || throw(DimensionMismatch("LinearCombination"))
8-
promote_type(T, eltype(maps[n])) == T || throw(InexactError())
6+
for n in eachindex(maps)
7+
A = maps[n]
8+
size(A) == sz || throw(DimensionMismatch("LinearCombination"))
9+
@assert promote_type(T, eltype(A)) == T "eltype $(eltype(A)) cannot be promoted to $T in LinearCombination constructor"
910
end
1011
new{T, As}(maps)
1112
end

src/scaledmap.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
struct ScaledMap{T, S<:RealOrComplex, A<:LinearMap} <: LinearMap{T}
22
λ::S
33
lmap::A
4-
function ScaledMap{T,S,A}::S, lmap::A) where {T, S <: RealOrComplex, A <: LinearMap}
5-
Base.promote_op(*, S, eltype(lmap)) == T || throw(InexactError())
4+
function ScaledMap{T}::S, lmap::A) where {T, S <: RealOrComplex, A <: LinearMap}
5+
@assert Base.promote_op(*, S, eltype(lmap)) == T "target type $T cannot hold products of $S and $(eltype(lmap)) objects"
66
new{T,S,A}(λ, lmap)
77
end
88
end
99

1010
# constructor
11-
ScaledMap{T}::S, lmap::A) where {T,S<:RealOrComplex,A<:LinearMap} =
12-
ScaledMap{Base.promote_op(*, S, eltype(lmap)),S,A}(λ, lmap)
11+
ScaledMap::S, lmap::A) where {S<:RealOrComplex,A<:LinearMap} =
12+
ScaledMap{Base.promote_op(*, S, eltype(lmap))}(λ, lmap)
1313

1414
# show
1515
function Base.show(io::IO, A::ScaledMap{T}) where {T}
@@ -32,14 +32,8 @@ Base.:(==)(A::ScaledMap, B::ScaledMap) =
3232
(eltype(A) == eltype(B) && A.lmap == B.lmap) && A.λ == B.λ
3333

3434
# scalar multiplication and division
35-
function Base.:(*)(α::RealOrComplex, A::LinearMap)
36-
T = Base.promote_op(*, typeof(α), eltype(A))
37-
return ScaledMap{T}(α, A)
38-
end
39-
function Base.:(*)(A::LinearMap, α::RealOrComplex)
40-
T = Base.promote_op(*, typeof(α), eltype(A))
41-
return ScaledMap{T}(α, A)
42-
end
35+
Base.:(*)(α::RealOrComplex, A::LinearMap) = ScaledMap(α, A)
36+
Base.:(*)(A::LinearMap, α::RealOrComplex) = ScaledMap(α, A)
4337

4438
Base.:(*)(α::Number, A::ScaledMap) =* A.λ) * A.lmap
4539
Base.:(*)(A::ScaledMap, α::Number) = A.lmap * (A.λ * α)

test/blockmap.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools, Interactive
99
L = @inferred hcat(LinearMap(A11), LinearMap(A12))
1010
@test @inferred(LinearMaps.MulStyle(L)) === matrixstyle
1111
@test L isa LinearMaps.BlockMap{elty}
12+
if elty <: Complex
13+
@test_throws AssertionError LinearMaps.BlockMap{Float64}((LinearMap(A11), LinearMap(A12)), (2,))
14+
end
1215
A = [A11 A12]
1316
x = rand(10+n2)
1417
@test size(L) == size(A)
@@ -182,6 +185,9 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools, Interactive
182185
M3 = randn(elty, m, n+3); L3 = LinearMap(M3)
183186

184187
# Md = diag(M1, M2, M3, M2, M1) # unsupported so use sparse:
188+
if elty <: Complex
189+
@test_throws AssertionError LinearMaps.BlockDiagonalMap{Float64}((L1, L2, L3, L2, L1))
190+
end
185191
Md = Matrix(blockdiag(sparse.((M1, M2, M3, M2, M1))...))
186192
@test (@which blockdiag(sparse.((M1, M2, M3, M2, M1))...)).module != LinearMaps
187193
@test (@which cat(M1, M2, M3, M2, M1; dims=(1,2))).module != LinearMaps

test/composition.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
66
FCM = LinearMaps.CompositeMap{ComplexF64}((FC,))
77
L = LowerTriangular(ones(10,10))
88
@test_throws DimensionMismatch F * LinearMap(rand(2,2))
9+
@test_throws AssertionError LinearMaps.CompositeMap{Float64}((FC, LinearMap(rand(10,10))))
910
A = 2 * rand(ComplexF64, (10, 10)) .- 1
1011
B = rand(size(A)...)
1112
H = LinearMap(Hermitian(A'A))

test/kronecker.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
88
LA = LinearMap(A)
99
LB = LinearMap(B)
1010
LK = @inferred kron(LA, LB)
11+
@test_throws AssertionError LinearMaps.KroneckerMap{Float64}((LA, LB))
1112
@test @inferred size(LK) == size(K)
1213
@test LinearMaps.MulStyle(LK) === LinearMaps.ThreeArg()
1314
for i in (1, 2)

test/linearcombination.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
1010
@test run(b, samples=3).allocs == 0
1111
n = 10
1212
L = sum(fill(CS!, n))
13+
@test_throws AssertionError LinearMaps.LinearCombination{Float64}((CS!, CS!))
1314
@test mul!(u, L, v) n * cumsum(v)
1415
b = @benchmarkable mul!($u, $L, $v, 2, 2)
1516
@test run(b, samples=5).allocs <= 1

0 commit comments

Comments
 (0)