Skip to content

Commit ca93cf0

Browse files
committed
make BlockMap/UniformScalingMap typing consistent with rest, inference tests
1 parent 681f0dc commit ca93cf0

File tree

3 files changed

+51
-61
lines changed

3 files changed

+51
-61
lines changed

src/blockmap.jl

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
struct BlockMap{T,As<:Tuple{Vararg{LinearMap}},Rs<:Tuple{Vararg{Int}}} <: LinearMap{T}
22
maps::As
33
rows::Rs
4-
function BlockMap(maps::R, rows::S) where {T, R<:Tuple{Vararg{LinearMap{T}}}, S<:Tuple{Vararg{Int}}}
5-
new{T,R,S}(maps, rows)
4+
function BlockMap{T,R,S}(As::R, rows::S) where {T, R<:Tuple{Vararg{LinearMap}}, S<:Tuple{Vararg{Int}}}
5+
for A in As
6+
promote_type(T, eltype(A)) == T || throw(InexactError())
7+
end
8+
new{T,R,S}(As, rows)
69
end
710
end
811

12+
BlockMap{T}(maps::As, rows::S) where {T,As<:Tuple{Vararg{LinearMap}},S} = BlockMap{T,As,S}(maps, rows)
13+
914
firstindices(maps::Tuple{Vararg{LinearMap}}, dim) = cumsum([1, map(m -> size(m, dim), maps)...,])
1015

1116
function check_dims(maps::Tuple{Vararg{LinearMap}}, k)
@@ -42,12 +47,6 @@ function Base.hcat(As::Union{LinearMap,UniformScaling}...)
4247
T = promote_type(map(eltype, As)...)
4348
nbc = length(As)
4449

45-
for A in As
46-
if !(A isa UniformScaling)
47-
eltype(A) == T || throw(ArgumentError("eltype mismatch in hcat of linear maps"))
48-
end
49-
end
50-
5150
nrows = 0
5251
# find first non-UniformScaling to detect number of rows
5352
for A in As
@@ -60,7 +59,7 @@ function Base.hcat(As::Union{LinearMap,UniformScaling}...)
6059

6160
maps = promote_to_lmaps(ntuple(i->nrows, nbc), 1, T, As...)
6261
check_dims(maps, 1)
63-
return BlockMap(maps, (length(As),))
62+
return BlockMap{T}(maps, (length(As),))
6463
end
6564

6665
############
@@ -71,12 +70,6 @@ function Base.vcat(As::Union{LinearMap,UniformScaling}...)
7170
T = promote_type(map(eltype, As)...)
7271
nbr = length(As)
7372

74-
for A in As
75-
if !(A isa UniformScaling)
76-
eltype(A) == T || throw(ArgumentError("eltype type mismatch in vcat of linear maps"))
77-
end
78-
end
79-
8073
ncols = 0
8174
# find first non-UniformScaling to detect number of columns
8275
for A in As
@@ -89,7 +82,7 @@ function Base.vcat(As::Union{LinearMap,UniformScaling}...)
8982

9083
maps = promote_to_lmaps(ntuple(i->ncols, nbr), 1, T, As...)
9184
check_dims(maps, 2)
92-
return BlockMap(maps, ntuple(i->1, length(As)))
85+
return BlockMap{T}(maps, ntuple(i->1, length(As)))
9386
end
9487

9588
############
@@ -148,7 +141,7 @@ function Base.hvcat(rows::NTuple{nr,Int}, As::Union{LinearMap,UniformScaling}...
148141
end
149142
end
150143

151-
return BlockMap(promote_to_lmaps(n, 1, T, As...), rows)
144+
return BlockMap{T}(promote_to_lmaps(n, 1, T, As...), rows)
152145
end
153146

154147
promote_to_lmaps_(n::Int, ::Type{T}, J::UniformScaling) where {T} = UniformScalingMap(convert(T, J.λ), n)

src/uniformscalingmap.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct UniformScalingMap{T} <: LinearMap{T} # T will be determined from maps to which this is added
1+
struct UniformScalingMap{T} <: LinearMap{T}
22
λ::T
33
M::Int
44
end
@@ -38,11 +38,7 @@ At_mul_B!(y::AbstractVector, A::UniformScalingMap, x::AbstractVector) = A_mul_B!
3838
Ac_mul_B!(y::AbstractVector, A::UniformScalingMap, x::AbstractVector) = A_mul_B!(y, adjoint(A), x)
3939

4040
# combine LinearMap and UniformScaling objects in linear combinations
41-
Base.:(+)(A1::LinearMap, A2::UniformScaling) =
42-
A1 + UniformScalingMap(convert(promote_type(eltype(A1), eltype(A2)), A2.λ), size(A1, 1))
43-
Base.:(+)(A1::UniformScaling, A2::LinearMap) =
44-
UniformScalingMap(convert(promote_type(eltype(A1), eltype(A2)), A1.λ), size(A2, 1)) + A2
45-
Base.:(-)(A1::LinearMap, A2::UniformScaling) =
46-
A1 - UniformScalingMap(convert(promote_type(eltype(A1), eltype(A2)), A2.λ), size(A1, 1))
47-
Base.:(-)(A1::UniformScaling, A2::LinearMap) =
48-
UniformScalingMap(convert(promote_type(eltype(A1), eltype(A2)), A1.λ), size(A2, 1)) - A2
41+
Base.:(+)(A1::LinearMap, A2::UniformScaling) = A1 + UniformScalingMap(A2.λ, size(A1, 1))
42+
Base.:(+)(A1::UniformScaling, A2::LinearMap) = UniformScalingMap(A1.λ, size(A2, 1)) + A2
43+
Base.:(-)(A1::LinearMap, A2::UniformScaling) = A1 - UniformScalingMap(A2.λ, size(A1, 1))
44+
Base.:(-)(A1::UniformScaling, A2::LinearMap) = UniformScalingMap(A1.λ, size(A2, 1)) - A2

test/runtests.jl

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ V = rand(ComplexF64, 10, 3)
1111
W = rand(ComplexF64, 20, 3)
1212
α = rand()
1313
β = rand()
14-
M = LinearMap(A)
15-
N = LinearMap(M)
14+
M = @inferred LinearMap(A)
15+
N = @inferred LinearMap(M)
1616

1717
@testset "LinearMaps.jl" begin
1818
@test eltype(M) == eltype(A)
@@ -74,7 +74,7 @@ end
7474
@test @inferred LinearMap(M')' * v == A * v
7575
@test @inferred transpose(transpose(M)) == M
7676
@test (M')' == M
77-
Mherm = LinearMap(A'A)
77+
Mherm = @inferred LinearMap(A'A)
7878
@test @inferred ishermitian(Mherm)
7979
@test @inferred !issymmetric(Mherm)
8080
@test @inferred !issymmetric(transpose(Mherm))
@@ -98,11 +98,11 @@ end
9898
@test @inferred mul!(copy(V), transpose(M), W) transpose(A) * W
9999
@test @inferred mul!(copy(V), adjoint(M), W) A' * W
100100

101-
B = LinearMap(Symmetric(rand(10, 10)))
101+
B = @inferred LinearMap(Symmetric(rand(10, 10)))
102102
@test transpose(B) == B
103103
@test B == transpose(B)
104104

105-
B = LinearMap(Hermitian(rand(ComplexF64, 10, 10)))
105+
B = @inferred LinearMap(Hermitian(rand(ComplexF64, 10, 10)))
106106
@test adjoint(B) == B
107107
@test B == B'
108108
end
@@ -121,40 +121,40 @@ end
121121
end
122122
return w
123123
end
124-
MyFT = LinearMap{ComplexF64}(myft, N) / sqrt(N)
124+
MyFT = @inferred LinearMap{ComplexF64}(myft, N) / sqrt(N)
125125
U = Matrix(MyFT) # will be a unitary matrix
126126
@test @inferred U'U Matrix{eltype(U)}(I, N, N)
127127

128-
CS = LinearMap(cumsum, 2)
128+
CS = @inferred LinearMap(cumsum, 2)
129129
@test size(CS) == (2, 2)
130130
@test @inferred !issymmetric(CS)
131131
@test @inferred !ishermitian(CS)
132132
@test @inferred !isposdef(CS)
133133
@test @inferred !(LinearMaps.ismutating(CS))
134134
@test @inferred Matrix(CS) == [1. 0.; 1. 1.]
135135
@test @inferred Array(CS) == [1. 0.; 1. 1.]
136-
CS = LinearMap(cumsum, 10; ismutating=false)
136+
CS = @inferred LinearMap(cumsum, 10; ismutating=false)
137137
v = rand(10)
138138
cv = cumsum(v)
139139
@test CS * v == cv
140140
@test *(CS, v) == cv
141141
@test_throws ErrorException CS' * v
142-
CS = LinearMap(cumsum, x -> cumsum(reverse(x)), 10; ismutating=false)
142+
CS = @inferred LinearMap(cumsum, x -> cumsum(reverse(x)), 10; ismutating=false)
143143
cv = cumsum(v)
144144
@test @inferred CS * v == cv
145145
@test @inferred *(CS, v) == cv
146146
@test @inferred CS' * v == cumsum(reverse(v))
147147
@test @inferred mul!(similar(v), transpose(CS), v) == cumsum(reverse(v))
148148

149-
CS! = LinearMap(cumsum!, 10; ismutating=true)
149+
CS! = @inferred LinearMap(cumsum!, 10; ismutating=true)
150150
@test @inferred LinearMaps.ismutating(CS!)
151151
@test @inferred CS! * v == cv
152152
@test @inferred *(CS!, v) == cv
153153
@test @inferred mul!(similar(v), CS!, v) == cv
154154
@test_throws ErrorException CS!'v
155155
@test_throws ErrorException transpose(CS!) * v
156156

157-
CS! = LinearMap{ComplexF64}(cumsum!, 10; ismutating=true)
157+
CS! = @inferred LinearMap{ComplexF64}(cumsum!, 10; ismutating=true)
158158
v = rand(ComplexF64, 10)
159159
cv = cumsum(v)
160160
@test @inferred LinearMaps.ismutating(CS!)
@@ -173,17 +173,17 @@ end
173173
@test @inferred mul!(similar(v), adjoint(CS), v) == cumsum(reverse(v))
174174

175175
# Test fallback methods:
176-
L = LinearMap(x -> x, x -> x, 10)
176+
L = @inferred LinearMap(x -> x, x -> x, 10)
177177
v = randn(10)
178178
@test @inferred (2 * L)' * v 2 * v
179179
@test @inferred transpose(2 * L) * v 2 * v
180-
L = LinearMap{ComplexF64}(x -> x, x -> x, 10)
180+
L = @inferred LinearMap{ComplexF64}(x -> x, x -> x, 10)
181181
v = rand(ComplexF64, 10)
182182
@test @inferred (2 * L)' * v 2 * v
183183
@test @inferred transpose(2 * L) * v 2 * v
184184
end
185185

186-
CS! = LinearMap(cumsum!, 10; ismutating=true)
186+
CS! = @inferred LinearMap(cumsum!, 10; ismutating=true)
187187
v = rand(10)
188188
u = similar(v)
189189
b = @benchmarkable mul!(u, CS!, v)
@@ -196,9 +196,9 @@ b = @benchmarkable mul!(u, L, v)
196196

197197
A = 2 * rand(ComplexF64, (10, 10)) .- 1
198198
B = rand(size(A)...)
199-
M = LinearMap(A)
200-
N = LinearMap(B)
201-
LC = M + N
199+
M = @inferred LinearMap(A)
200+
N = @inferred LinearMap(B)
201+
LC = @inferred M + N
202202
v = rand(ComplexF64, 10)
203203
w = similar(v)
204204
b = @benchmarkable mul!(w, M, v)
@@ -251,11 +251,11 @@ Base.:(*)(A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, v::Vector) = A.f
251251
mul!(y::Vector, A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, x::Vector) = copyto!(y, *(A, x))
252252

253253
@testset "composition" begin
254-
F = LinearMap(cumsum, 10; ismutating=false)
254+
F = @inferred LinearMap(cumsum, 10; ismutating=false)
255255
A = 2 * rand(ComplexF64, (10, 10)) .- 1
256256
B = rand(size(A)...)
257-
M = 1 * LinearMap(A)
258-
N = LinearMap(B)
257+
M = @inferred 1 * LinearMap(A)
258+
N = @inferred LinearMap(B)
259259
@test @inferred (F * F) * v == @inferred F * (F * v)
260260
@test @inferred (F * A) * v == @inferred F * (A * v)
261261
@test @inferred (A * F) * v == @inferred A * (F * v)
@@ -276,20 +276,21 @@ mul!(y::Vector, A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, x::Vector)
276276
@test @inferred transpose(M * F) == @inferred transpose(F) * transpose(M)
277277
@test @inferred (4*((-3*M)*2)) == @inferred -12M*2
278278
@test @inferred (4*((3*(-M))*2)*(-5)) == @inferred -12M*(-10)
279-
L = 3 * F + 1im * A + F * M' * F
279+
L = @inferred 3 * F + 1im * A + F * M' * F
280280
LF = 3 * Matrix(F) + 1im * A + Matrix(F) * Matrix(M)' * Matrix(F)
281281
@test Array(L) LF
282282
R1 = rand(ComplexF64, 10, 10); L1 = LinearMap(R1)
283283
R2 = rand(ComplexF64, 10, 10); L2 = LinearMap(R2)
284284
R3 = rand(ComplexF64, 10, 10); L3 = LinearMap(R3)
285285
CompositeR = prod(R -> LinearMap(R), [R1, R2, R3])
286-
@test transpose(CompositeR) == transpose(L3) * transpose(L2) * transpose(L1)
287-
@test adjoint(CompositeR) == L3' * L2' * L1'
288-
@test adjoint(adjoint((CompositeR))) == CompositeR
286+
@test @inferred L1 * L2 * L3 == CompositeR
287+
@test @inferred transpose(CompositeR) == transpose(L3) * transpose(L2) * transpose(L1)
288+
@test @inferred adjoint(CompositeR) == L3' * L2' * L1'
289+
@test @inferred adjoint(adjoint((CompositeR))) == CompositeR
289290
@test transpose(transpose((CompositeR))) == CompositeR
290-
Lt = transpose(LinearMap(CompositeR))
291+
Lt = @inferred transpose(LinearMap(CompositeR))
291292
@test Lt * v transpose(R3) * transpose(R2) * transpose(R1) * v
292-
Lc = adjoint(LinearMap(CompositeR))
293+
Lc = @inferred adjoint(LinearMap(CompositeR))
293294
@test Lc * v R3' * R2' * R1' * v
294295

295296
# test inplace operations
@@ -360,9 +361,9 @@ A = rand(10, 20)
360361
B = rand(ComplexF64, 10, 20)
361362
SA = A'A + I
362363
SB = B'B + I
363-
L = LinearMap{Float64}(A)
364-
MA = LinearMap(SA)
365-
MB = LinearMap(SB)
364+
L = @inferred LinearMap{Float64}(A)
365+
MA = @inferred LinearMap(SA)
366+
MB = @inferred LinearMap(SB)
366367
@testset "wrapped maps" begin
367368
@test size(L) == size(A)
368369
@test @inferred !issymmetric(L)
@@ -375,12 +376,12 @@ end
375376
A = 2 * rand(ComplexF64, (10, 10)) .- 1
376377
B = rand(size(A)...)
377378
M = @inferred 1 * LinearMap(A)
378-
N = LinearMap(B)
379-
LC = M + N
379+
N = @inferred LinearMap(B)
380+
LC = @inferred M + N
380381
v = rand(ComplexF64, 10)
381382
w = similar(v)
382383
@testset "identity/scaling map" begin
383-
Id = LinearMaps.UniformScalingMap(1, 10)
384+
Id = @inferred LinearMaps.UniformScalingMap(1, 10)
384385
@test_throws ErrorException LinearMaps.UniformScalingMap(1, 10, 20)
385386
@test_throws ErrorException LinearMaps.UniformScalingMap(1, (10, 20))
386387
@test size(Id) == (10, 10)
@@ -538,16 +539,16 @@ end
538539
@test size(L) == size(A)
539540
@test L * x A * x
540541
@test Matrix(L) A
541-
Lt = transform(L)
542+
Lt = @inferred transform(L)
542543
@test Lt isa LinearMaps.LinearMap{elty}
543544
@test Lt * x transform(A) * x
544-
Lt = transform(LinearMap(L))
545+
Lt = @inferred transform(LinearMap(L))
545546
@test Lt * x transform(A) * x
546547
@test Matrix(Lt) Matrix(transform(A))
547548
A21 = rand(elty, 10, 10)
548549
A = [I A12; A21 I]
549550
L = [I LinearMap(A12); LinearMap(A21) I]
550-
Lt = transform(L)
551+
Lt = @inferred transform(L)
551552
@test Lt isa LinearMaps.LinearMap{elty}
552553
@test Lt * x transform(A) * x
553554
@test Matrix(Lt) Matrix(transform(A))

0 commit comments

Comments
 (0)