11module KroneckerArraysBlockSparseArraysExt
22
3- using BlockArrays: Block
4- using BlockSparseArrays: BlockIndexVector, GenericBlockIndex
5- using KroneckerArrays: CartesianPair, CartesianProduct
6- function Base. getindex (
7- b:: Block{N} ,
8- I:: Vararg{Union{CartesianPair, CartesianProduct}, N}
9- ) where {N}
10- return GenericBlockIndex (b, I)
11- end
12- function Base. getindex (b:: Block{N} , I:: Vararg{CartesianProduct, N} ) where {N}
13- return BlockIndexVector (b, I)
14- end
3+ using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerVector,
4+ CartesianPair, CartesianProduct, CartesianProductUnitRange,
5+ kroneckerfactors, ⊗ , isactive, cartesianrange
6+ using BlockArrays: BlockArrays, Block, AbstractBlockedUnitRange, mortar
7+ using BlockSparseArrays: BlockSparseArrays, BlockIndexVector, GenericBlockIndex, ZeroBlocks,
8+ blockrange, eachblockaxis, mortar_axis
9+ using DiagonalArrays: ShapeInitializer
1510
16- using BlockSparseArrays: BlockSparseArrays, blockrange
17- using KroneckerArrays: CartesianPair, CartesianProduct, cartesianrange
18- function BlockSparseArrays. blockrange (bs:: Vector{<:CartesianPair} )
19- return blockrange (map (cartesianrange, bs))
20- end
21- function BlockSparseArrays. blockrange (bs:: Vector{<:CartesianProduct} )
22- return blockrange (map (cartesianrange, bs))
23- end
2411
25- using BlockArrays: BlockArrays, mortar
26- using BlockSparseArrays: blockrange
27- using KroneckerArrays: CartesianProductUnitRange
12+ Base. getindex (b:: Block{N} , I:: Vararg{Union{CartesianPair, CartesianProduct}, N} ) where {N} =
13+ GenericBlockIndex (b, I)
14+ Base. getindex (b:: Block{N} , I:: Vararg{CartesianProduct, N} ) where {N} =
15+ BlockIndexVector (b, I)
16+
17+ BlockSparseArrays. blockrange (bs:: Vector{<:CartesianPair} ) = blockrange (map (cartesianrange, bs))
18+ BlockSparseArrays. blockrange (bs:: Vector{<:CartesianProduct} ) = blockrange (map (cartesianrange, bs))
19+
2820# Makes sure that `mortar` results in a `BlockVector` with the correct
2921# axes, otherwise the axes would not preserve the Kronecker structure.
3022# This is helpful when indexing `BlockUnitRange`, for example:
3123# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.7.1/src/blockaxis.jl#L540-L547
32- function BlockArrays. mortar (blocks:: AbstractVector{<:CartesianProductUnitRange} )
33- return mortar (blocks, (blockrange (map (Base. axes1, blocks)),))
34- end
24+ BlockArrays. mortar (blocks:: AbstractVector{<:CartesianProductUnitRange} ) =
25+ mortar (blocks, (blockrange (map (Base. axes1, blocks)),))
3526
36- using BlockArrays: AbstractBlockedUnitRange
37- using BlockSparseArrays: Block, ZeroBlocks, eachblockaxis, mortar_axis
38- using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗ , arg1, arg2, isactive
3927
40- function KroneckerArrays. arg1 (r:: AbstractBlockedUnitRange )
41- return mortar_axis (arg1 .(eachblockaxis (r)))
42- end
43- function KroneckerArrays. arg2 (r:: AbstractBlockedUnitRange )
44- return mortar_axis (arg2 .(eachblockaxis (r)))
45- end
28+ KroneckerArrays. kroneckerfactors (r:: AbstractBlockedUnitRange , i:: Int ) =
29+ mortar_axis (kroneckerfactors .(eachblockaxis (r), i))
30+ KroneckerArrays. kroneckerfactors (r:: AbstractBlockedUnitRange ) =
31+ (kroneckerfactors (r, 1 ), kroneckerfactors (r, 2 ))
4632
47- function block_axes (
48- ax:: NTuple{N, AbstractUnitRange{<:Integer}} , I:: Vararg{Block{1}, N}
49- ) where {N}
33+ function block_axes (ax:: NTuple{N, AbstractUnitRange{<:Integer}} , I:: Vararg{Block{1}, N} ) where {N}
5034 return ntuple (N) do d
5135 return only (axes (ax[d][I[d]]))
5236 end
5337end
54- function block_axes (ax:: NTuple{N, AbstractUnitRange{<:Integer}} , I:: Block{N} ) where {N}
55- return block_axes (ax, Tuple (I)... )
56- end
57-
58- using DiagonalArrays: ShapeInitializer
38+ block_axes (ax:: NTuple{N, AbstractUnitRange{<:Integer}} , I:: Block{N} ) where {N} =
39+ block_axes (ax, Tuple (I)... )
5940
6041# # TODO : Is this needed?
6142function Base. getindex (
6243 a:: ZeroBlocks{N, KroneckerArray{T, N, A1, A2}} , I:: Vararg{Int, N}
6344 ) where {T, N, A1 <: AbstractArray{T, N} , A2 <: AbstractArray{T, N} }
64- ax_a1 = map (arg1, a. parentaxes)
65- ax_a2 = map (arg2, a. parentaxes)
66- block_ax_a1 = arg1 .(block_axes (a. parentaxes, Block (I)))
67- block_ax_a2 = arg2 .(block_axes (a. parentaxes, Block (I)))
45+ ax_a1 = kroneckerfactors .( a. parentaxes, 1 )
46+ ax_a2 = kroneckerfactors .( a. parentaxes, 2 )
47+ block_ax_a1 = kroneckerfactors .(block_axes (a. parentaxes, Block (I)), 1 )
48+ block_ax_a2 = kroneckerfactors .(block_axes (a. parentaxes, Block (I)), 2 )
6849 # TODO : Is this a good definition? It is similar to
6950 # the definition of `similar` and `adapt_structure`.
7051 return if isactive (A1) == isactive (A2)
@@ -76,10 +57,7 @@ function Base.getindex(
7657 end
7758end
7859
79- using BlockSparseArrays: BlockSparseArrays
80- using KroneckerArrays: KroneckerArrays, KroneckerVector
81- function BlockSparseArrays. to_truncated_indices (values:: KroneckerVector , I)
82- return KroneckerArrays. to_truncated_indices (values, I)
83- end
60+ BlockSparseArrays. to_truncated_indices (values:: KroneckerVector , I) =
61+ KroneckerArrays. to_truncated_indices (values, I)
8462
8563end
0 commit comments