Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.10"
version = "0.1.11"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"

[weakdeps]
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"

[extensions]
KroneckerArraysBlockSparseArraysExt = "BlockSparseArrays"

[compat]
Adapt = "4.3.0"
BlockSparseArrays = "0.7.9"
DerivableInterfaces = "0.5.0"
DiagonalArrays = "0.3.5"
FillArrays = "1.13.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module KroneckerArraysBlockSparseArraysExt

using BlockSparseArrays: BlockSparseArrays, blockrange
using KroneckerArrays: CartesianProduct, cartesianrange

function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
return blockrange(map(cartesianrange, bs))
end

end
20 changes: 14 additions & 6 deletions src/KroneckerArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ end
const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}

using Adapt: Adapt, adapt
Adapt.adapt_structure(to, a::KroneckerArray) = adapt(to, a.a) ⊗ adapt(to, a.b)

function Base.copy(a::KroneckerArray)
return copy(a.a) ⊗ copy(a.b)
end
Expand Down Expand Up @@ -930,6 +933,11 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}

using Adapt: Adapt, adapt
Adapt.adapt_structure(to, a::SquareEyeKronecker) = a.a ⊗ adapt(to, a.b)
Adapt.adapt_structure(to, a::KroneckerSquareEye) = adapt(to, a.a) ⊗ a.b
Adapt.adapt_structure(to, a::SquareEyeSquareEye) = adapt(to, a.a) ⊗ a.b

# Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`.
function Base.similar(
a::SquareEyeKronecker,
Expand Down Expand Up @@ -970,22 +978,22 @@ function Base.similar(
end

function Base.similar(
arrayt::Type{<:SquareEyeKronecker{<:Any,<:Any,A}},
arrayt::Type{<:SquareEyeKronecker{T,A,B}},
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
) where {A}
) where {T,A<:SquareEye{T},B}
ax_a = map(ax -> ax.product.a, axs)
ax_b = map(ax -> ax.product.b, axs)
eye_ax_a = (only(unique(ax_a)),)
return Eye{eltype(arrayt)}(eye_ax_a) ⊗ similar(A, ax_b)
return Eye{T}(eye_ax_a) ⊗ similar(B, ax_b)
end
function Base.similar(
arrayt::Type{<:KroneckerSquareEye{<:Any,A}},
arrayt::Type{<:KroneckerSquareEye{T,A,B}},
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
) where {A}
) where {T,A,B<:SquareEye{T}}
ax_a = map(ax -> ax.product.a, axs)
ax_b = map(ax -> ax.product.b, axs)
eye_ax_b = (only(unique(ax_b)),)
return similar(A, ax_a) ⊗ Eye{eltype(arrayt)}(eye_ax_b)
return similar(A, ax_a) ⊗ Eye{T}(eye_ax_b)
end
function Base.similar(
arrayt::Type{<:SquareEyeSquareEye}, axs::NTuple{2,CartesianProductUnitRange{<:Integer}}
Expand Down
8 changes: 8 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
Expand All @@ -12,9 +16,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
Adapt = "4"
Aqua = "0.8"
BlockArrays = "1.6"
BlockSparseArrays = "0.7"
DerivableInterfaces = "0.5"
FillArrays = "1"
JLArrays = "0.2"
KroneckerArrays = "0.1"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.2"
Expand Down
16 changes: 15 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using Adapt: adapt
using Base.Broadcast: BroadcastStyle, Broadcasted, broadcasted
using DerivableInterfaces: zero!
using FillArrays: Eye
using JLArrays: JLArray
using KroneckerArrays:
KroneckerArrays,
KroneckerArray,
Expand All @@ -17,7 +19,7 @@ using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd,
using StableRNGs: StableRNG
using Test: @test, @test_broken, @test_throws, @testset

const elts = (Float32, Float64, ComplexF32, ComplexF64)
elts = (Float32, Float64, ComplexF32, ComplexF64)
@testset "KroneckerArrays (eltype=$elt)" for elt in elts
p = [1, 2] × [3, 4, 5]
@test length(p) == 6
Expand Down Expand Up @@ -78,6 +80,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
@test norm(a) ≈ norm(collect(a))

# Broadcasting
a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
style = KroneckerStyle(BroadcastStyle(typeof(a.a)), BroadcastStyle(typeof(a.b)))
@test BroadcastStyle(typeof(a)) === style
@test_throws "not supported" sin.(a)
Expand All @@ -94,6 +97,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
@test_broken copy(bc)

# Mapping
a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
@test_throws "not supported" map(sin, a)
@test_broken map(Base.Fix1(*, 2), a)
a′ = similar(a)
Expand All @@ -120,6 +124,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
map!(conj, a′, a)
@test collect(a′) ≈ conj(collect(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
if elt <: Real
@test real(a) == a
else
Expand All @@ -131,6 +136,15 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
@test_throws ArgumentError imag(a)
end

# Adapt
a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
a′ = adapt(JLArray, a)
@test a′ isa KroneckerArray{elt,2,JLArray{elt,2},JLArray{elt,2}}
@test a′.a isa JLArray{elt,2}
@test a′.b isa JLArray{elt,2}
@test Array(a′.a) == a.a
@test Array(a′.b) == a.b

a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3)
@test collect(a) ≈ kron_nd(a.a, a.b)
@test a[1 × 1, 1 × 1, 1 × 1] == a.a[1, 1, 1] * a.b[1, 1, 1]
Expand Down
135 changes: 135 additions & 0 deletions test/test_blocksparsearrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
using Adapt: adapt
using BlockArrays: Block, BlockRange
using BlockSparseArrays:
BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype
using FillArrays: Eye, SquareEye
using JLArrays: JLArray
using KroneckerArrays: KroneckerArray, ⊗, ×
using LinearAlgebra: norm
using MatrixAlgebraKit: svd_compact
using Test: @test, @test_broken, @testset
using TestExtras: @constinferred

elts = (Float32, Float64, ComplexF32)
arrayts = (Array, JLArray)
@testset "BlockSparseArraysExt, KroneckerArray blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in
arrayts,
elt in elts

dev = adapt(arrayt)
r = blockrange([2 × 2, 3 × 3])
d = Dict(
Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
@test_broken sprint(show, a)
@test sprint(show, MIME("text/plain"), a) isa String
@test blocktype(a) === valtype(d)
@test a isa BlockSparseMatrix{elt,valtype(d)}
@test a[Block(1, 1)] == dev(d[Block(1, 1)])
@test a[Block(1, 1)] isa valtype(d)
@test a[Block(2, 2)] == dev(d[Block(2, 2)])
@test a[Block(2, 2)] isa valtype(d)
@test iszero(a[Block(2, 1)])
@test a[Block(2, 1)] == dev(zeros(elt, 3, 2) ⊗ zeros(elt, 3, 2))
@test a[Block(2, 1)] isa valtype(d)
@test iszero(a[Block(1, 2)])
@test a[Block(1, 2)] == dev(zeros(elt, 2, 3) ⊗ zeros(elt, 2, 3))
@test a[Block(1, 2)] isa valtype(d)

b = a * a
@test typeof(b) === typeof(a)
@test Array(b) ≈ Array(a) * Array(a)

b = a + a
@test typeof(b) === typeof(a)
@test Array(b) ≈ Array(a) + Array(a)

b = 3a
@test typeof(b) === typeof(a)
@test Array(b) ≈ 3Array(a)

b = a / 3
@test typeof(b) === typeof(a)
@test Array(b) ≈ Array(a) / 3

@test norm(a) ≈ norm(Array(a))

if arrayt == Array
@test Array(inv(a)) ≈ inv(Array(a))
else
# Broken for JLArray, it seems like `inv` isn't
# type stable.
@test_broken inv(a)
end

# Broken operations
@test_broken exp(a)
@test_broken svd_compact(a)
@test_broken a[Block.(1:2), Block(2)]
end

@testset "BlockSparseArraysExt, SquareEyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in
arrayts,
elt in elts

if arrayt == JLArray
# TODO: Collecting to `Array` is broken for GPU arrays so a lot of tests
# are broken, look into fixing that.
continue
end

dev = adapt(arrayt)
r = blockrange([2 × 2, 3 × 3])
d = Dict(
Block(1, 1) => Eye{elt}(2, 2) ⊗ randn(elt, 2, 2),
Block(2, 2) => Eye{elt}(3, 3) ⊗ randn(elt, 3, 3),
)
a = dev(blocksparse(d, r, r))
@test_broken sprint(show, a)
@test sprint(show, MIME("text/plain"), a) isa String
@test_broken blocktype(a) === valtype(d)
@test_broken a isa BlockSparseMatrix{elt,valtype(d)}
@test a[Block(1, 1)] == dev(d[Block(1, 1)])
@test_broken a[Block(1, 1)] isa valtype(d)
@test a[Block(2, 2)] == dev(d[Block(2, 2)])
@test_broken a[Block(2, 2)] isa valtype(d)
@test iszero(a[Block(2, 1)])
@test a[Block(2, 1)] == dev(zeros(elt, 3, 2) ⊗ zeros(elt, 3, 2))
@test_broken a[Block(2, 1)] isa valtype(d)
@test iszero(a[Block(1, 2)])
@test a[Block(1, 2)] == dev(zeros(elt, 2, 3) ⊗ zeros(elt, 2, 3))
@test_broken a[Block(1, 2)] isa valtype(d)

b = a * a
@test typeof(b) === typeof(a)
@test Array(b) ≈ Array(a) * Array(a)

b = a + a
@test typeof(b) === typeof(a)
@test Array(b) ≈ Array(a) + Array(a)

b = 3a
@test typeof(b) === typeof(a)
@test Array(b) ≈ 3Array(a)

b = a / 3
@test typeof(b) === typeof(a)
@test Array(b) ≈ Array(a) / 3

@test norm(a) ≈ norm(Array(a))

if arrayt == Array
@test Array(inv(a)) ≈ inv(Array(a))
else
# Broken for JLArray, it seems like `inv` isn't
# type stable.
@test_broken inv(a)
end

# Broken operations
# @test_broken exp(a)
@test_broken svd_compact(a)
@test_broken a[Block.(1:2), Block(2)]
end
Loading