|
| 1 | +struct BlockMap{T,As<:Tuple{Vararg{LinearMap}},Rs<:Tuple{Vararg{Int}}} <: LinearMap{T} |
| 2 | + maps::As |
| 3 | + rows::Rs |
| 4 | + function BlockMap(maps::R, rows::S) where {T, R<:Tuple{Vararg{LinearMap{T}}}, S<:Tuple{Vararg{Int}}} |
| 5 | + new{T,R,S}(maps, rows) |
| 6 | + end |
| 7 | +end |
| 8 | + |
| 9 | +firstindices(maps::Tuple{Vararg{LinearMap}}, dim) = cumsum([1, map(m -> size(m, dim), maps)...,]) |
| 10 | + |
| 11 | +function check_dims(maps::Tuple{Vararg{LinearMap}}, k) |
| 12 | + n = size(maps[1], k) |
| 13 | + for map in maps |
| 14 | + n == size(map, k) || throw(DimensionMismatch("Expected $n, got $(size(map, k))")) |
| 15 | + end |
| 16 | + return nothing |
| 17 | +end |
| 18 | + |
| 19 | +function Base.size(A::BlockMap) |
| 20 | + as, rows = A.maps, A.rows |
| 21 | + |
| 22 | + nbr = length(rows) # number of block rows |
| 23 | + nc = 0 |
| 24 | + for i in 1:rows[1] |
| 25 | + nc += size(as[i],2) |
| 26 | + end |
| 27 | + |
| 28 | + nr = 0 |
| 29 | + a = 1 |
| 30 | + for i in 1:nbr |
| 31 | + nr += size(as[a],1) |
| 32 | + a += rows[i] |
| 33 | + end |
| 34 | + return nr, nc |
| 35 | +end |
| 36 | + |
| 37 | +############ |
| 38 | +# hcat |
| 39 | +############ |
| 40 | + |
| 41 | +function Base.hcat(As::Union{LinearMap,UniformScaling}...) |
| 42 | + T = promote_type(map(eltype, As)...) |
| 43 | + nbc = length(As) |
| 44 | + |
| 45 | + for A in As |
| 46 | + if !(A isa UniformScaling) |
| 47 | + eltype(A) == T || throw(ArgumentError("eltype mismatch in hcat of linear maps")) |
| 48 | + end |
| 49 | + end |
| 50 | + |
| 51 | + nrows = 0 |
| 52 | + # find first non-UniformScaling to detect number of rows |
| 53 | + for A in As |
| 54 | + if !(A isa UniformScaling) |
| 55 | + nrows = size(A, 1) |
| 56 | + break |
| 57 | + end |
| 58 | + end |
| 59 | + nrows == 0 && throw(ArgumentError("hcat of only UniformScaling-like objects cannot determine the linear map size")) |
| 60 | + |
| 61 | + maps = promote_to_lmaps(ntuple(i->nrows, nbc), 1, T, As...) |
| 62 | + check_dims(maps, 1) |
| 63 | + return BlockMap(maps, (length(As),)) |
| 64 | +end |
| 65 | + |
| 66 | +############ |
| 67 | +# vcat |
| 68 | +############ |
| 69 | + |
| 70 | +function Base.vcat(As::Union{LinearMap,UniformScaling}...) |
| 71 | + T = promote_type(map(eltype, As)...) |
| 72 | + nbr = length(As) |
| 73 | + |
| 74 | + for A in As |
| 75 | + if !(A isa UniformScaling) |
| 76 | + eltype(A) == T || throw(ArgumentError("eltype type mismatch in vcat of linear maps")) |
| 77 | + end |
| 78 | + end |
| 79 | + |
| 80 | + ncols = 0 |
| 81 | + # find first non-UniformScaling to detect number of columns |
| 82 | + for A in As |
| 83 | + if !(A isa UniformScaling) |
| 84 | + ncols = size(A, 2) |
| 85 | + break |
| 86 | + end |
| 87 | + end |
| 88 | + ncols == 0 && throw(ArgumentError("hcat of only UniformScaling-like objects cannot determine the linear map size")) |
| 89 | + |
| 90 | + maps = promote_to_lmaps(ntuple(i->ncols, nbr), 1, T, As...) |
| 91 | + check_dims(maps, 2) |
| 92 | + return BlockMap(maps, ntuple(i->1, length(As))) |
| 93 | +end |
| 94 | + |
| 95 | +############ |
| 96 | +# hvcat |
| 97 | +############ |
| 98 | + |
| 99 | +function Base.hvcat(rows::NTuple{nr,Int}, As::Union{LinearMap,UniformScaling}...) where nr |
| 100 | + T = promote_type(map(eltype, As)...) |
| 101 | + sum(rows) == length(As) || throw(ArgumentError("mismatch between row sizes and number of arguments")) |
| 102 | + n = fill(-1, length(As)) |
| 103 | + needcols = false # whether we also need to infer some sizes from the column count |
| 104 | + j = 0 |
| 105 | + for i in 1:nr # infer UniformScaling sizes from row counts, if possible: |
| 106 | + ni = -1 # number of rows in this block-row, -1 indicates unknown |
| 107 | + for k in 1:rows[i] |
| 108 | + if !isa(As[j+k], UniformScaling) |
| 109 | + na = size(As[j+k], 1) |
| 110 | + ni >= 0 && ni != na && |
| 111 | + throw(DimensionMismatch("mismatch in number of rows")) |
| 112 | + ni = na |
| 113 | + end |
| 114 | + end |
| 115 | + if ni >= 0 |
| 116 | + for k = 1:rows[i] |
| 117 | + n[j+k] = ni |
| 118 | + end |
| 119 | + else # row consisted only of UniformScaling objects |
| 120 | + needcols = true |
| 121 | + end |
| 122 | + j += rows[i] |
| 123 | + end |
| 124 | + if needcols # some sizes still unknown, try to infer from column count |
| 125 | + nc = -1 |
| 126 | + j = 0 |
| 127 | + for i in 1:nr |
| 128 | + nci = 0 |
| 129 | + rows[i] > 0 && n[j+1] == -1 && (j += rows[i]; continue) |
| 130 | + for k = 1:rows[i] |
| 131 | + nci += isa(As[j+k], UniformScaling) ? n[j+k] : size(As[j+k], 2) |
| 132 | + end |
| 133 | + nc >= 0 && nc != nci && throw(DimensionMismatch("mismatch in number of columns")) |
| 134 | + nc = nci |
| 135 | + j += rows[i] |
| 136 | + end |
| 137 | + nc == -1 && throw(ArgumentError("sizes of UniformScalings could not be inferred")) |
| 138 | + j = 0 |
| 139 | + for i in 1:nr |
| 140 | + if rows[i] > 0 && n[j+1] == -1 # this row consists entirely of UniformScalings |
| 141 | + nci = nc ÷ rows[i] |
| 142 | + nci * rows[i] != nc && throw(DimensionMismatch("indivisible UniformScaling sizes")) |
| 143 | + for k = 1:rows[i] |
| 144 | + n[j+k] = nci |
| 145 | + end |
| 146 | + end |
| 147 | + j += rows[i] |
| 148 | + end |
| 149 | + end |
| 150 | + |
| 151 | + return BlockMap(promote_to_lmaps(n, 1, T, As...), rows) |
| 152 | +end |
| 153 | + |
| 154 | +promote_to_lmaps_(n::Int, ::Type{T}, J::UniformScaling) where {T} = UniformScalingMap(convert(T, J.λ), n) |
| 155 | +promote_to_lmaps_(n::Int, ::Type{T}, A::LinearMap{T}) where {T} = A |
| 156 | +promote_to_lmaps(n, k, ::Type) = () |
| 157 | +promote_to_lmaps(n, k, ::Type{T}, A) where {T} = (promote_to_lmaps_(n[k], T, A),) |
| 158 | +promote_to_lmaps(n, k, ::Type{T}, A, B) where {T} = |
| 159 | + (promote_to_lmaps_(n[k], T, A), promote_to_lmaps_(n[k+1], T, B)) |
| 160 | +promote_to_lmaps(n, k, ::Type{T}, A, B, C) where {T} = |
| 161 | + (promote_to_lmaps_(n[k], T, A), promote_to_lmaps_(n[k+1], T, B), promote_to_lmaps_(n[k+2], T, C)) |
| 162 | +promote_to_lmaps(n, k, ::Type{T}, A, B, Cs...) where {T} = |
| 163 | + (promote_to_lmaps_(n[k], T, A), promote_to_lmaps_(n[k+1], T, B), promote_to_lmaps(n, k+2, T, Cs...)...) |
| 164 | + |
| 165 | +############ |
| 166 | +# basic methods |
| 167 | +############ |
| 168 | + |
| 169 | +# function LinearAlgebra.issymmetric(A::BlockMap) |
| 170 | +# m, n = nblocks(A) |
| 171 | +# m == n || return false |
| 172 | +# for i in 1:m, j in i:m |
| 173 | +# if (i == j && !issymmetric(getblock(A, i, i))) |
| 174 | +# return false |
| 175 | +# elseif getblock(A, i, j) != transpose(getblock(A, j, i)) |
| 176 | +# return false |
| 177 | +# end |
| 178 | +# end |
| 179 | +# return true |
| 180 | +# end |
| 181 | +# |
| 182 | +# LinearAlgebra.ishermitian(A::BlockMap{<:Real}) = issymmetric(A) |
| 183 | +# function LinearAlgebra.ishermitian(A::BlockMap) |
| 184 | +# m, n = nblocks(A) |
| 185 | +# m == n || return false |
| 186 | +# for i in 1:m, j in i:m |
| 187 | +# if (i == j && !ishermitian(getblock(A, i, i))) |
| 188 | +# return false |
| 189 | +# elseif getblock(A, i, j) != adjoint(getblock(A, j, i)) |
| 190 | +# return false |
| 191 | +# end |
| 192 | +# end |
| 193 | +# return true |
| 194 | +# end |
| 195 | +# TODO, currently falls back on the generic `false` |
| 196 | +# LinearAlgebra.isposdef(A::BlockMap) |
| 197 | + |
| 198 | +############ |
| 199 | +# comparison of BlockMap objects, sufficient but not necessary |
| 200 | +############ |
| 201 | + |
| 202 | +Base.:(==)(A::BlockMap, B::BlockMap) = (eltype(A) == eltype(B) && A.maps == B.maps && A.rows == B.rows) |
| 203 | + |
| 204 | +# special transposition behavior |
| 205 | + |
| 206 | +LinearAlgebra.transpose(A::BlockMap) = TransposeMap(A) |
| 207 | +LinearAlgebra.adjoint(A::BlockMap) = AdjointMap(A) |
| 208 | + |
| 209 | +############ |
| 210 | +# multiplication with vectors |
| 211 | +############ |
| 212 | + |
| 213 | +function A_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) |
| 214 | + maps, rows = A.maps, A.rows |
| 215 | + mapind = 0 |
| 216 | + yinds = firstindices(maps[cumsum([1, rows...])[1:end-1]], 1) |
| 217 | + @views for rowind in 1:length(rows) |
| 218 | + xinds = firstindices(maps[mapind+1:mapind+rows[rowind]], 2) |
| 219 | + yrow = @views y[yinds[rowind]:(yinds[rowind+1]-1)] |
| 220 | + mapind += 1 |
| 221 | + A_mul_B!(yrow, maps[mapind], x[xinds[1]:xinds[2]-1]) |
| 222 | + for colind in 2:rows[rowind] |
| 223 | + mapind +=1 |
| 224 | + mul!(yrow, maps[mapind], x[xinds[colind]:xinds[colind+1]-1], 1, 1) |
| 225 | + end |
| 226 | + end |
| 227 | + return y |
| 228 | +end |
| 229 | + |
| 230 | +function At_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) |
| 231 | + maps, rows = A.maps, A.rows |
| 232 | + fill!(y, 0) |
| 233 | + mapind = 0 |
| 234 | + xinds = firstindices(maps[cumsum([1, rows...])[1:end-1]], 1) |
| 235 | + # first block row (rowind = 1), fill all of y |
| 236 | + yinds = firstindices(maps[mapind+1:mapind+rows[1]], 2) |
| 237 | + xcol = @views x[xinds[1]:(xinds[2]-1)] |
| 238 | + @views for colind in 1:rows[1] |
| 239 | + mapind +=1 |
| 240 | + A_mul_B!(y[yinds[colind]:yinds[colind+1]-1], transpose(maps[mapind]), xcol) |
| 241 | + end |
| 242 | + # subsequent block rows, add results to corresponding parts of y |
| 243 | + @views for rowind in 2:length(rows) |
| 244 | + yinds = firstindices(maps[mapind+1:mapind+rows[rowind]], 2) |
| 245 | + xcol = @views x[xinds[rowind]:(xinds[rowind+1]-1)] |
| 246 | + for colind in 1:rows[rowind] |
| 247 | + mapind +=1 |
| 248 | + mul!(y[yinds[colind]:yinds[colind+1]-1], transpose(maps[mapind]), xcol, 1, 1) |
| 249 | + end |
| 250 | + end |
| 251 | + return y |
| 252 | +end |
| 253 | + |
| 254 | +function Ac_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) |
| 255 | + maps, rows = A.maps, A.rows |
| 256 | + fill!(y, 0) |
| 257 | + mapind = 0 |
| 258 | + xinds = firstindices(maps[cumsum([1, rows...])[1:end-1]], 1) |
| 259 | + # first block row (rowind = 1), fill all of y |
| 260 | + yinds = firstindices(maps[mapind+1:mapind+rows[1]], 2) |
| 261 | + xcol = @views x[xinds[1]:(xinds[2]-1)] |
| 262 | + @views for colind in 1:rows[1] |
| 263 | + mapind +=1 |
| 264 | + A_mul_B!(y[yinds[colind]:yinds[colind+1]-1], adjoint(maps[mapind]), xcol) |
| 265 | + end |
| 266 | + # subsequent block rows, add results to corresponding parts of y |
| 267 | + @views for rowind in 2:length(rows) |
| 268 | + yinds = firstindices(maps[mapind+1:mapind+rows[rowind]], 2) |
| 269 | + xcol = @views x[xinds[rowind]:(xinds[rowind+1]-1)] |
| 270 | + for colind in 1:rows[rowind] |
| 271 | + mapind +=1 |
| 272 | + mul!(y[yinds[colind]:yinds[colind+1]-1], adjoint(maps[mapind]), xcol, 1, 1) |
| 273 | + end |
| 274 | + end |
| 275 | + return y |
| 276 | +end |
| 277 | + |
| 278 | +############ |
| 279 | +# show methods |
| 280 | +############ |
| 281 | + |
| 282 | +# block2string(b, s) = string(join(map(string, b), '×'), "-blocked ", Base.dims2string(s)) |
| 283 | +# Base.summary(a::BlockMap) = string(block2string(nblocks(a), size(a)), " ", typeof(a)) |
| 284 | +# # _show_typeof(io, a) = show(io, typeof(a)) |
| 285 | +# function Base.summary(io::IO, a::AbstractBlockMap) |
| 286 | +# print(io, block2string(nblocks(a), size(a))) |
| 287 | +# print(io, ' ') |
| 288 | +# _show_typeof(io, a) |
| 289 | +# end |
| 290 | +# function _show_typeof(io::IO, a::AbstractBlockMap{T}) where {T} |
| 291 | +# Base.show_type_name(io, typeof(a).name) |
| 292 | +# print(io, '{') |
| 293 | +# show(io, T) |
| 294 | +# print(io, '}') |
| 295 | +# end |
0 commit comments