Skip to content

Commit f38f818

Browse files
committed
Fix some tests
1 parent d72be3f commit f38f818

File tree

4 files changed

+50
-26
lines changed

4 files changed

+50
-26
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1414
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
15+
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1516

1617
[weakdeps]
1718
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -37,4 +38,5 @@ MapBroadcast = "0.1.10"
3738
MatrixAlgebraKit = "0.2, 0.3"
3839
TensorAlgebra = "0.3.10"
3940
TensorProducts = "0.1.7"
41+
TypeParameterAccessors = "0.4.2"
4042
julia = "1.10"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ end
3636

3737
using BlockArrays: AbstractBlockedUnitRange
3838
using BlockSparseArrays: Block, ZeroBlocks, eachblockaxis, mortar_axis
39-
using KroneckerArrays: KroneckerArrays, KroneckerArray, , arg1, arg2
40-
using BlockSparseArrays.TypeParameterAccessors: unwrap_array_type
39+
using KroneckerArrays: KroneckerArrays, KroneckerArray, , arg1, arg2, isactive
4140

4241
function KroneckerArrays.arg1(r::AbstractBlockedUnitRange)
4342
return mortar_axis(arg1.(eachblockaxis(r)))
@@ -59,28 +58,21 @@ end
5958

6059
## TODO: Is this needed?
6160
function Base.getindex(
62-
a::ZeroBlocks{N,KroneckerArray{T,N,A,B}}, I::Vararg{Int,N}
63-
) where {T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}}
61+
a::ZeroBlocks{N,KroneckerArray{T,N,A1,A2}}, I::Vararg{Int,N}
62+
) where {T,N,A1<:AbstractArray{T,N},A2<:AbstractArray{T,N}}
6463
ax_a1 = map(arg1, a.parentaxes)
6564
ax_a2 = map(arg2, a.parentaxes)
66-
# TODO: Instead of mutability, maybe have a trait like
67-
# `isstructural` or `isdata`.
68-
ismut1 = ismutabletype(unwrap_array_type(A))
69-
ismut2 = ismutabletype(unwrap_array_type(B))
70-
(ismut1 || ismut2) || error("Can't get zero block.")
71-
a1 = if ismut1
72-
ZeroBlocks{N,A}(ax_a1)[I...]
73-
else
74-
block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I)))
75-
similar(A, block_ax_a1)
65+
block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I)))
66+
block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I)))
67+
# TODO: Is this a good definition? It is similar to
68+
# the definition of `similar` and `adapt_structure`.
69+
return if isactive(A1) == isactive(A2)
70+
ZeroBlocks{N,A1}(ax_a1)[I...] ZeroBlocks{N,A2}(ax_a2)[I...]
71+
elseif isactive(A1)
72+
ZeroBlocks{N,A1}(ax_a1)[I...] A2(block_ax_a2)
73+
elseif isactive(A2)
74+
A1(block_ax_a1) ZeroBlocks{N,A2}(ax_a2)[I...]
7675
end
77-
a2 = if ismut2
78-
ZeroBlocks{N,B}(ax_a2)[I...]
79-
else
80-
block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I)))
81-
a2 = similar(B, block_ax_a2)
82-
end
83-
return a1 a2
8476
end
8577

8678
using BlockSparseArrays: BlockSparseArrays

src/kroneckerarray.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,30 @@
1+
# TODO: Move this to DiagonalArrays.jl.
2+
using DiagonalArrays: DiagonalArrays, _DiagonalArray, DiagonalArray, Unstored
3+
# TODO: Also support size inputs.
4+
function DiagonalArrays.DiagonalArray{T,N,D,U}(
5+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
6+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
7+
# TODO: Support these constructors.
8+
# return DiagonalArray{T,N,Diag,Unstored}(Diag((Base.OneTo(minimum(length, ax)),)), Unstored(U(ax)))
9+
# return DiagonalArray{T,N,Diag}(Diag((Base.OneTo(minimum(length, ax)),)), Unstored(U(ax)))
10+
# return DiagonalArray{T,N}(Diag((Base.OneTo(minimum(length, ax)),)), Unstored(U(ax)))
11+
# return DiagonalArray{T}(D((Base.OneTo(minimum(length, ax)),)), Unstored(U(ax)))
12+
# return DiagonalArray(D((Base.OneTo(minimum(length, ax)),)), Unstored(U(ax)))
13+
return _DiagonalArray(D((Base.OneTo(minimum(length, ax)),)), U(ax))
14+
end
15+
116
function unwrap_array(a::AbstractArray)
217
p = parent(a)
318
p a && return a
419
return unwrap_array(p)
520
end
621
isactive(a::AbstractArray) = ismutable(unwrap_array(a))
722

23+
using TypeParameterAccessors: unwrap_array_type
24+
function isactive(arrayt::Type{<:AbstractArray})
25+
return ismutabletype(unwrap_array_type(arrayt))
26+
end
27+
828
# Custom `_convert` works around the issue that
929
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isn't defined
1030
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
@@ -56,7 +76,17 @@ function mutate_active_args!(f!, f, dest, src)
5676
end
5777

5878
using Adapt: Adapt, adapt
59-
Adapt.adapt_structure(to, a::KroneckerArray) = adapt(to, arg1(a)) adapt(to, arg2(a))
79+
function Adapt.adapt_structure(to, a::KroneckerArray)
80+
# TODO: Is this a good definition? It is similar to
81+
# the definition of `similar`.
82+
return if isactive(arg1(a)) == isactive(arg2(a))
83+
adapt(to, arg1(a)) adapt(to, arg2(a))
84+
elseif isactive(arg1(a))
85+
adapt(to, arg1(a)) arg2(a)
86+
elseif isactive(arg2(a))
87+
arg1(a) adapt(to, arg2(a))
88+
end
89+
end
6090

6191
function Base.copy(a::KroneckerArray)
6292
return copy(arg1(a)) copy(arg2(a))

test/test_blocksparsearrays.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,10 @@ end
161161
elt in elts
162162

163163
dev = adapt(arrayt)
164-
r = @constinferred blockrange([2 × 2, 3 × 3])
164+
r = @constinferred blockrange([2 × 2, 2 × 3])
165165
d = Dict(
166166
Block(1, 1) => δ(elt, (2, 2)) dev(randn(elt, 2, 2)),
167-
Block(2, 2) => δ(elt, (3, 3)) dev(randn(elt, 3, 3)),
167+
Block(2, 2) => δ(elt, (2, 2)) dev(randn(elt, 3, 3)),
168168
)
169169
a = @constinferred dev(blocksparse(d, (r, r)))
170170
@test sprint(show, a) == sprint(show, Array(a))
@@ -176,10 +176,10 @@ end
176176
@test @constinferred(a[Block(2, 2)]) == dev(d[Block(2, 2)])
177177
@test @constinferred(a[Block(2, 2)]) isa valtype(d)
178178
@test @constinferred(iszero(a[Block(2, 1)]))
179-
@test a[Block(2, 1)] == dev(δ(3, 2) zeros(elt, 3, 2))
179+
@test a[Block(2, 1)] == dev(δ(2, 2) zeros(elt, 3, 2))
180180
@test a[Block(2, 1)] isa valtype(d)
181181
@test @constinferred(iszero(a[Block(1, 2)]))
182-
@test a[Block(1, 2)] == dev(δ(2, 3) zeros(elt, 2, 3))
182+
@test a[Block(1, 2)] == dev(δ(2, 2) zeros(elt, 2, 3))
183183
@test a[Block(1, 2)] isa valtype(d)
184184

185185
# Slicing

0 commit comments

Comments
 (0)