Skip to content

Commit e5a383f

Browse files
authored
Allow BlockMap construction with matrices (#71)
1 parent 3031bc0 commit e5a383f

File tree

4 files changed

+50
-61
lines changed

4 files changed

+50
-61
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ julia = "1"
1111

1212
[extras]
1313
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
14+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
1617
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1718
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1819

1920
[targets]
20-
test = ["LinearAlgebra", "SparseArrays", "Test", "BenchmarkTools", "Quaternions"]
21+
test = ["LinearAlgebra", "SparseArrays", "Test", "BenchmarkTools", "InteractiveUtils", "Quaternions"]

src/blockmap.jl

Lines changed: 27 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@ BlockMap{T}(maps::As, rows::S) where {T,As<:Tuple{Vararg{LinearMap}},S} = BlockM
1616

1717
MulStyle(A::BlockMap) = MulStyle(A.maps...)
1818

19-
function check_dim(A::LinearMap, dim, n)
20-
n == size(A, dim) || throw(DimensionMismatch("Expected $n, got $(size(A, dim))"))
21-
return nothing
22-
end
23-
2419
"""
2520
rowcolranges(maps, rows)
2621
@@ -51,31 +46,14 @@ end
5146

5247
Base.size(A::BlockMap) = (last(last(A.rowranges)), last(last(A.colranges)))
5348

54-
############
55-
# concatenation
56-
############
57-
58-
for k in 1:8 # is 8 sufficient?
59-
Is = ntuple(n->:($(Symbol(:A,n))::UniformScaling), Val(k-1))
60-
L = :($(Symbol(:A,k))::LinearMap)
61-
args = ntuple(n->Symbol(:A,n), Val(k))
62-
63-
@eval begin
64-
Base.hcat($(Is...), $L, As::Union{LinearMap,UniformScaling}...) = _hcat($(args...), As...)
65-
Base.vcat($(Is...), $L, As::Union{LinearMap,UniformScaling}...) = _vcat($(args...), As...)
66-
Base.hvcat(rows::Tuple{Vararg{Int}}, $(Is...), $L, As::Union{LinearMap,UniformScaling}...) = _hvcat(rows, $(args...), As...)
67-
end
68-
end
69-
7049
############
7150
# hcat
7251
############
7352
"""
74-
hcat(As::Union{LinearMap,UniformScaling}...)::BlockMap
53+
hcat(As::Union{LinearMap,UniformScaling,AbstractVecOrMat}...)::BlockMap
7554
7655
Construct a (lazy) representation of the horizontal concatenation of the arguments.
77-
`UniformScaling` objects are promoted to `LinearMap` automatically. To avoid fallback
78-
to the generic `Base.hcat`, there must be a `LinearMap` object among the first 8 arguments.
56+
All arguments are promoted to `LinearMap`s automatically.
7957
8058
# Examples
8159
```jldoctest; setup=(using LinearMaps)
@@ -90,9 +68,7 @@ julia> L * ones(Int, 6)
9068
6
9169
```
9270
"""
93-
Base.hcat
94-
95-
function _hcat(As::Union{LinearMap,UniformScaling}...)
71+
function Base.hcat(As::Union{LinearMap,UniformScaling,AbstractVecOrMat}...)
9672
T = promote_type(map(eltype, As)...)
9773
nbc = length(As)
9874

@@ -112,11 +88,10 @@ end
11288
# vcat
11389
############
11490
"""
115-
vcat(As::Union{LinearMap,UniformScaling}...)::BlockMap
91+
vcat(As::Union{LinearMap,UniformScaling,AbstractVecOrMat}...)::BlockMap
11692
11793
Construct a (lazy) representation of the vertical concatenation of the arguments.
118-
`UniformScaling` objects are promoted to `LinearMap` automatically. To avoid fallback
119-
to the generic `Base.vcat`, there must be a `LinearMap` object among the first 8 arguments.
94+
All arguments are promoted to `LinearMap`s automatically.
12095
12196
# Examples
12297
```jldoctest; setup=(using LinearMaps)
@@ -134,9 +109,7 @@ julia> L * ones(Int, 3)
134109
3
135110
```
136111
"""
137-
Base.vcat
138-
139-
function _vcat(As::Union{LinearMap,UniformScaling}...)
112+
function Base.vcat(As::Union{LinearMap,UniformScaling,AbstractVecOrMat}...)
140113
T = promote_type(map(eltype, As)...)
141114
nbr = length(As)
142115

@@ -157,12 +130,11 @@ end
157130
# hvcat
158131
############
159132
"""
160-
hvcat(rows::Tuple{Vararg{Int}}, As::Union{LinearMap,UniformScaling}...)::BlockMap
133+
hvcat(rows::Tuple{Vararg{Int}}, As::Union{LinearMap,UniformScaling,AbstractVecOrMat}...)::BlockMap
161134
162135
Construct a (lazy) representation of the horizontal-vertical concatenation of the arguments.
163136
The first argument specifies the number of arguments to concatenate in each block row.
164-
`UniformScaling` objects are promoted to `LinearMap` automatically. To avoid fallback
165-
to the generic `Base.hvcat`, there must be a `LinearMap` object among the first 8 arguments.
137+
All arguments are promoted to `LinearMap`s automatically.
166138
167139
# Examples
168140
```jldoctest; setup=(using LinearMaps)
@@ -185,7 +157,7 @@ julia> L * ones(Int, 6)
185157
"""
186158
Base.hvcat
187159

188-
function _hvcat(rows::Tuple{Vararg{Int}}, As::Union{LinearMap,UniformScaling}...)
160+
function Base.hvcat(rows::Tuple{Vararg{Int}}, As::Union{LinearMap,UniformScaling,AbstractVecOrMat}...)
189161
nr = length(rows)
190162
T = promote_type(map(eltype, As)...)
191163
sum(rows) == length(As) || throw(ArgumentError("mismatch between row sizes and number of arguments"))
@@ -237,6 +209,13 @@ function _hvcat(rows::Tuple{Vararg{Int}}, As::Union{LinearMap,UniformScaling}...
237209
return BlockMap{T}(promote_to_lmaps(n, 1, 1, As...), rows)
238210
end
239211

212+
function check_dim(A, dim, n)
213+
n == size(A, dim) || throw(DimensionMismatch("Expected $n, got $(size(A, dim))"))
214+
return nothing
215+
end
216+
217+
promote_to_lmaps_(n::Int, dim, A::AbstractMatrix) = (check_dim(A, dim, n); LinearMap(A))
218+
promote_to_lmaps_(n::Int, dim, A::AbstractVector) = (check_dim(A, dim, n); LinearMap(reshape(A, length(A), 1)))
240219
promote_to_lmaps_(n::Int, dim, J::UniformScaling) = UniformScalingMap(J.λ, n)
241220
promote_to_lmaps_(n::Int, dim, A::LinearMap) = (check_dim(A, dim, n); A)
242221
promote_to_lmaps(n, k, dim) = ()
@@ -292,11 +271,6 @@ end
292271

293272
Base.:(==)(A::BlockMap, B::BlockMap) = (eltype(A) == eltype(B) && A.maps == B.maps && A.rows == B.rows)
294273

295-
# special transposition behavior
296-
297-
LinearAlgebra.transpose(A::BlockMap) = TransposeMap(A)
298-
LinearAlgebra.adjoint(A::BlockMap) = AdjointMap(A)
299-
300274
############
301275
# multiplication helper functions
302276
############
@@ -310,6 +284,7 @@ function _blockmul!(y, A::BlockMap, x, α, β)
310284
return __blockmul!(MulStyle(A), y, A, x, α, β)
311285
end
312286

287+
# provide one global intermediate storage vector if necessary
313288
__blockmul!(::FiveArg, y, A, x, α, β) = ___blockmul!(y, A, x, α, β, nothing)
314289
__blockmul!(::ThreeArg, y, A, x, α, β) = ___blockmul!(y, A, x, α, β, similar(y))
315290

@@ -401,7 +376,6 @@ for (intype, outtype) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, A
401376
function _unsafe_mul!(y::$outtype, wrapA::$maptype, x::$intype,
402377
α::Number, β::Number)
403378
require_one_based_indexing(y, x)
404-
405379
return _transblockmul!(y, wrapA.lmap, x, α, β, $transform)
406380
end
407381
end
@@ -437,22 +411,24 @@ BlockDiagonalMap{T}(maps::As) where {T,As<:Tuple{Vararg{LinearMap}}} =
437411
BlockDiagonalMap(maps::LinearMap...) =
438412
BlockDiagonalMap{promote_type(map(eltype, maps)...)}(maps)
439413

414+
# since the below methods are more specific than the Base method,
415+
# they would redefine Base/SparseArrays behavior
440416
for k in 1:8 # is 8 sufficient?
441-
Is = ntuple(n->:($(Symbol(:A,n))::AbstractMatrix), Val(k-1))
417+
Is = ntuple(n->:($(Symbol(:A,n))::AbstractVecOrMat), Val(k-1))
442418
# yields (:A1, :A2, :A3, ..., :A(k-1))
443419
L = :($(Symbol(:A,k))::LinearMap)
444420
# yields :Ak
445-
mapargs = ntuple(n -> :(LinearMap($(Symbol(:A,n)))), Val(k-1))
421+
mapargs = ntuple(n -> :($(Symbol(:A,n))), Val(k-1))
446422
# yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1)))
447423

448424
@eval begin
449-
function SparseArrays.blockdiag($(Is...), $L, As::Union{LinearMap,AbstractMatrix}...)
450-
return BlockDiagonalMap($(mapargs...), $(Symbol(:A,k)), convert_to_lmaps(As...)...)
425+
function SparseArrays.blockdiag($(Is...), $L, As::Union{LinearMap,AbstractVecOrMat}...)
426+
return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., $(Symbol(:A,k)), convert_to_lmaps(As...)...)
451427
end
452428

453-
function Base.cat($(Is...), $L, As::Union{LinearMap,AbstractMatrix}...; dims::Dims{2})
429+
function Base.cat($(Is...), $L, As::Union{LinearMap,AbstractVecOrMat}...; dims::Dims{2})
454430
if dims == (1,2)
455-
return BlockDiagonalMap($(mapargs...), $(Symbol(:A,k)), convert_to_lmaps(As...)...)
431+
return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., $(Symbol(:A,k)), convert_to_lmaps(As...)...)
456432
else
457433
throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)"))
458434
end
@@ -461,7 +437,7 @@ for k in 1:8 # is 8 sufficient?
461437
end
462438

463439
"""
464-
blockdiag(As::Union{LinearMap,AbstractMatrix}...)::BlockDiagonalMap
440+
blockdiag(As::Union{LinearMap,AbstractVecOrMat}...)::BlockDiagonalMap
465441
466442
Construct a (lazy) representation of the diagonal concatenation of the arguments.
467443
To avoid fallback to the generic `SparseArrays.blockdiag`, there must be a `LinearMap`
@@ -470,7 +446,7 @@ object among the first 8 arguments.
470446
SparseArrays.blockdiag
471447

472448
"""
473-
cat(As::Union{LinearMap,AbstractMatrix}...; dims=(1,2))::BlockDiagonalMap
449+
cat(As::Union{LinearMap,AbstractVecOrMat}...; dims=(1,2))::BlockDiagonalMap
474450
475451
Construct a (lazy) representation of the diagonal concatenation of the arguments.
476452
To avoid fallback to the generic `Base.cat`, there must be a `LinearMap`

src/conversion.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ function SparseArrays.sparse(A::BlockMap)
132132
convert.(AbstractMatrix, Base.tail(A.maps))...
133133
)
134134
end
135-
Base.Matrix{T}(A::BlockDiagonalMap) where {T} = cat(convert.(Matrix{T}, A.maps)...; dims=(1,2))
135+
Base.Matrix{T}(A::BlockDiagonalMap) where {T} = Base._cat((1,2), convert.(Matrix{T}, A.maps)...)
136136
Base.convert(::Type{AbstractMatrix}, A::BlockDiagonalMap) = sparse(A)
137137
function SparseArrays.sparse(A::BlockDiagonalMap)
138138
return blockdiag(convert.(SparseMatrixCSC, A.maps)...)

test/blockmap.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools
1+
using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools, InteractiveUtils
22

33
@testset "block maps" begin
44
@testset "hcat" begin
55
for elty in (Float32, ComplexF64), n2 = (0, 20)
66
A11 = rand(elty, 10, 10)
77
A12 = rand(elty, 10, n2)
8+
v = rand(elty, 10)
89
L = @inferred hcat(LinearMap(A11), LinearMap(A12))
910
@test @inferred(LinearMaps.MulStyle(L)) === matrixstyle
1011
@test L isa LinearMaps.BlockMap{elty}
@@ -16,10 +17,14 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools
1617
L = @inferred hcat(LinearMap(A11), LinearMap(A12), LinearMap(A11))
1718
A = [A11 A12 A11]
1819
@test Matrix(L) A
19-
A = [I I I A11 A11 A11]
20-
L = @inferred hcat(I, I, I, LinearMap(A11), LinearMap(A11), LinearMap(A11))
21-
@test L == [I I I LinearMap(A11) LinearMap(A11) LinearMap(A11)]
22-
x = rand(elty, 60)
20+
A = [I I I A11 A11 A11 v]
21+
@test (@which [A11 A11 A11]).module != LinearMaps
22+
@test (@which [I I I A11 A11 A11]).module != LinearMaps
23+
@test (@which hcat(I, I, I)).module != LinearMaps
24+
@test (@which hcat(I, I, I, LinearMap(A11), A11, A11)).module == LinearMaps
25+
L = @inferred hcat(I, I, I, LinearMap(A11), A11, A11, v)
26+
@test L == [I I I LinearMap(A11) LinearMap(A11) LinearMap(A11) LinearMap(reshape(v, :, 1))]
27+
x = rand(elty, 61)
2328
@test L isa LinearMaps.BlockMap{elty}
2429
@test L * x A * x
2530
A11 = rand(elty, 11, 10)
@@ -31,21 +36,24 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools
3136
@testset "vcat" begin
3237
for elty in (Float32, ComplexF64)
3338
A11 = rand(elty, 10, 10)
39+
v = rand(elty, 10)
3440
L = @inferred vcat(LinearMap(A11))
3541
@test L == [LinearMap(A11);]
3642
@test Matrix(L) A11
3743
A21 = rand(elty, 20, 10)
3844
L = @inferred vcat(LinearMap(A11), LinearMap(A21))
3945
@test L isa LinearMaps.BlockMap{elty}
4046
@test @inferred(LinearMaps.MulStyle(L)) === matrixstyle
47+
@test (@which [A11; A21]).module != LinearMaps
4148
A = [A11; A21]
4249
x = rand(10)
4350
@test size(L) == size(A)
4451
@test Matrix(L) A
4552
@test L * x A * x
46-
A = [I; I; I; A11; A11; A11]
47-
L = @inferred vcat(I, I, I, LinearMap(A11), LinearMap(A11), LinearMap(A11))
48-
@test L == [I; I; I; LinearMap(A11); LinearMap(A11); LinearMap(A11)]
53+
A = [I; I; I; A11; A11; A11; v v v v v v v v v v]
54+
@test (@which [I; I; I; A11; A11; A11; v v v v v v v v v v]).module != LinearMaps
55+
L = @inferred vcat(I, I, I, LinearMap(A11), LinearMap(A11), LinearMap(A11), reduce(hcat, fill(v, 10)))
56+
@test L == [I; I; I; LinearMap(A11); LinearMap(A11); LinearMap(A11); reduce(hcat, fill(v, 10))]
4957
x = rand(elty, 10)
5058
@test L isa LinearMaps.BlockMap{elty}
5159
@test L * x A * x
@@ -62,6 +70,7 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools
6270
A21 = rand(elty, 20, 10)
6371
A22 = rand(elty, 20, 20)
6472
A = [A11 A12; A21 A22]
73+
@test (@which [A11 A12; A21 A22]).module != LinearMaps
6574
@inferred hvcat((2,2), LinearMap(A11), LinearMap(A12), LinearMap(A21), LinearMap(A22))
6675
L = [LinearMap(A11) LinearMap(A12); LinearMap(A21) LinearMap(A22)]
6776
@test @inferred(LinearMaps.MulStyle(L)) === matrixstyle
@@ -74,6 +83,7 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools
7483
@test Matrix(L) == A
7584
@test convert(AbstractMatrix, L) == A
7685
A = [I A12; A21 I]
86+
@test (@which [I A12; A21 I]).module != LinearMaps
7787
@inferred hvcat((2,2), I, LinearMap(A12), LinearMap(A21), I)
7888
L = @inferred hvcat((2,2), I, LinearMap(A12), LinearMap(A21), I)
7989
@test L isa LinearMaps.BlockMap{elty}
@@ -173,6 +183,8 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools
173183

174184
# Md = diag(M1, M2, M3, M2, M1) # unsupported so use sparse:
175185
Md = Matrix(blockdiag(sparse.((M1, M2, M3, M2, M1))...))
186+
@test (@which blockdiag(sparse.((M1, M2, M3, M2, M1))...)).module != LinearMaps
187+
@test (@which cat(M1, M2, M3, M2, M1; dims=(1,2))).module != LinearMaps
176188
x = randn(elty, size(Md, 2))
177189
Bd = @inferred blockdiag(L1, L2, L3, L2, L1)
178190
@test Matrix(Bd) == Md

0 commit comments

Comments
 (0)