Skip to content

Commit 02c7dea

Browse files
authored
More general ZeroBlocks definition (#38)
1 parent 2ffe0a0 commit 02c7dea

File tree

2 files changed

+22
-44
lines changed

2 files changed

+22
-44
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.29"
4+
version = "0.1.30"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 21 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,8 @@ end
2525

2626
using BlockArrays: AbstractBlockedUnitRange
2727
using BlockSparseArrays: Block, ZeroBlocks, eachblockaxis, mortar_axis
28-
using DerivableInterfaces: zero!
29-
using FillArrays: Eye
30-
using KroneckerArrays:
31-
KroneckerArrays,
32-
EyeEye,
33-
EyeKronecker,
34-
KroneckerEye,
35-
KroneckerMatrix,
36-
,
37-
arg1,
38-
arg2,
39-
_similar
28+
using KroneckerArrays: KroneckerArrays, KroneckerArray, , arg1, arg2, _similar
29+
using BlockSparseArrays.TypeParameterAccessors: unwrap_array_type
4030

4131
function KroneckerArrays.arg1(r::AbstractBlockedUnitRange)
4232
return mortar_axis(arg1.(eachblockaxis(r)))
@@ -58,41 +48,29 @@ end
5848

5949
## TODO: Is this needed?
6050
function Base.getindex(
61-
a::ZeroBlocks{2,KroneckerMatrix{T,A,B}}, I::Vararg{Int,2}
62-
) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}}
51+
a::ZeroBlocks{N,KroneckerArray{T,N,A,B}}, I::Vararg{Int,N}
52+
) where {T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}}
6353
ax_a1 = map(arg1, a.parentaxes)
64-
a1 = ZeroBlocks{2,A}(ax_a1)[I...]
6554
ax_a2 = map(arg2, a.parentaxes)
66-
a2 = ZeroBlocks{2,B}(ax_a2)[I...]
67-
return a1 a2
68-
end
69-
function Base.getindex(
70-
a::ZeroBlocks{2,EyeKronecker{T,A,B}}, I::Vararg{Int,2}
71-
) where {T,A<:Eye{T},B<:AbstractMatrix{T}}
72-
block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I)))
73-
a1 = _similar(A, block_ax_a1)
74-
75-
ax_a2 = arg2.(a.parentaxes)
76-
a2 = ZeroBlocks{2,B}(ax_a2)[I...]
77-
78-
return a1 a2
79-
end
80-
function Base.getindex(
81-
a::ZeroBlocks{2,KroneckerEye{T,A,B}}, I::Vararg{Int,2}
82-
) where {T,A<:AbstractMatrix{T},B<:Eye{T}}
83-
ax_a1 = arg1.(a.parentaxes)
84-
a1 = ZeroBlocks{2,A}(ax_a1)[I...]
85-
86-
block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I)))
87-
a2 = _similar(B, block_ax_a2)
88-
55+
# TODO: Instead of mutability, maybe have a trait like
56+
# `isstructural` or `isdata`.
57+
ismut1 = ismutabletype(unwrap_array_type(A))
58+
ismut2 = ismutabletype(unwrap_array_type(B))
59+
(ismut1 || ismut2) || error("Can't get zero block.")
60+
a1 = if ismut1
61+
ZeroBlocks{N,A}(ax_a1)[I...]
62+
else
63+
block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I)))
64+
_similar(A, block_ax_a1)
65+
end
66+
a2 = if ismut2
67+
ZeroBlocks{N,B}(ax_a2)[I...]
68+
else
69+
block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I)))
70+
a2 = _similar(B, block_ax_a2)
71+
end
8972
return a1 a2
9073
end
91-
function Base.getindex(
92-
a::ZeroBlocks{2,EyeEye{T,A,B}}, I::Vararg{Int,2}
93-
) where {T,A<:Eye{T},B<:Eye{T}}
94-
return error("Not implemented.")
95-
end
9674

9775
using BlockSparseArrays: BlockSparseArrays
9876
using KroneckerArrays: KroneckerArrays, KroneckerVector

0 commit comments

Comments
 (0)