Skip to content

Commit 35b5aba

Browse files
authored
Merge branch 'main' into mf/delta_support
2 parents 49460d6 + 2c2d41e commit 35b5aba

File tree

11 files changed

+89
-46
lines changed

11 files changed

+89
-46
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ ci:
33

44
repos:
55
- repo: https://github.com/pre-commit/pre-commit-hooks
6-
rev: v5.0.0
6+
rev: v6.0.0
77
hooks:
88
- id: check-merge-conflict
99
- id: check-toml
@@ -12,6 +12,6 @@ repos:
1212
exclude_types: [markdown] # incompatible with Literate.jl
1313

1414
- repo: "https://github.com/domluna/JuliaFormatter.jl"
15-
rev: v2.1.2
15+
rev: v2.1.6
1616
hooks:
1717
- id: "julia-formatter"

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.27"
4+
version = "0.1.28"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -26,7 +26,7 @@ KroneckerArraysTensorProductsExt = "TensorProducts"
2626
[compat]
2727
Adapt = "4.3"
2828
BlockArrays = "1.6"
29-
BlockSparseArrays = "0.8.1"
29+
BlockSparseArrays = "0.9"
3030
DerivableInterfaces = "0.5"
3131
DiagonalArrays = "0.3.11"
3232
FillArrays = "1.13"
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module KroneckerArraysTensorProductsExt
2+
3+
using KroneckerArrays: CartesianProductOneTo, ×, arg1, arg2, cartesianrange, unproduct
4+
using TensorProducts: TensorProducts, tensor_product
5+
function TensorProducts.tensor_product(a1::CartesianProductOneTo, a2::CartesianProductOneTo)
6+
prod = tensor_product(arg1(a1), arg1(a2)) × tensor_product(arg2(a1), arg2(a2))
7+
range = tensor_product(unproduct(a1), unproduct(a2))
8+
return cartesianrange(prod, range)
9+
end
10+
11+
end

src/cartesianproduct.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ unproduct(r::CartesianProductVector) = getfield(r, :values)
6262
Base.length(a::CartesianProductVector) = length(unproduct(a))
6363
Base.size(a::CartesianProductVector) = (length(a),)
6464
function Base.axes(r::CartesianProductVector)
65-
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
65+
prod = cartesianproduct(r)
66+
prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod)))
67+
return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),)
6668
end
6769
function Base.copy(a::CartesianProductVector)
6870
return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a)))

src/fillarrays/kroneckerarray.jl

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,11 @@ function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T}
5454
return RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
5555
end
5656

57-
# Like `similar` but preserves `Eye`.
58-
function _similar(a::AbstractArray, elt::Type, ax::Tuple)
59-
return similar(a, elt, ax)
57+
# Like `similar` but preserves `Eye`, `Ones`, etc.
58+
using FillArrays: Ones
59+
function _similar(arrayt::Type{<:Ones}, axs::Tuple)
60+
return Ones{eltype(arrayt)}(axs)
6061
end
61-
function _similar(A::Type{<:AbstractArray}, ax::Tuple)
62-
return similar(A, ax)
63-
end
64-
function _similar(a::AbstractArray, ax::Tuple)
65-
return _similar(a, eltype(a), ax)
66-
end
67-
function _similar(a::AbstractArray, elt::Type)
68-
return _similar(a, elt, axes(a))
69-
end
70-
function _similar(a::AbstractArray)
71-
return _similar(a, eltype(a), axes(a))
72-
end
73-
74-
# Like `similar` but preserves `Eye`.
7562
function _similar(a::Eye, elt::Type, axs::NTuple{2,AbstractUnitRange})
7663
return Eye{elt}(axs)
7764
end

src/kroneckerarray.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,19 @@ function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where
7171
end
7272

7373
# Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`.
74-
function _similar(a::AbstractArray, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}})
74+
function _similar(a::AbstractArray, elt::Type, axs::Tuple)
7575
return similar(a, elt, axs)
7676
end
77-
function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple{Vararg{AbstractUnitRange}})
77+
function _similar(a::AbstractArray, ax::Tuple)
78+
return _similar(a, eltype(a), ax)
79+
end
80+
function _similar(a::AbstractArray, elt::Type)
81+
return _similar(a, elt, axes(a))
82+
end
83+
function _similar(a::AbstractArray)
84+
return _similar(a, eltype(a), axes(a))
85+
end
86+
function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple)
7887
return similar(arrayt, axs)
7988
end
8089

@@ -164,6 +173,16 @@ Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))
164173

165174
Base.zero(a::KroneckerArray) = zero(arg1(a)) zero(arg2(a))
166175

176+
using DerivableInterfaces: DerivableInterfaces, zero!
177+
function DerivableInterfaces.zero!(a::KroneckerArray)
178+
ismut1 = ismutable(arg1(a))
179+
ismut2 = ismutable(arg2(a))
180+
(ismut1 || ismut2) || throw(ArgumentError("Can't zero out immutable KroneckerArray."))
181+
ismut1 && zero!(arg1(a))
182+
ismut2 && zero!(arg2(a))
183+
return a
184+
end
185+
167186
function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
168187
return convert(Array{T,N}, collect(a))
169188
end
@@ -412,13 +431,15 @@ _eltype(x) = eltype(x)
412431
_eltype(x::Broadcasted) = Base.promote_op(x.f, _eltype.(x.args)...)
413432

414433
using Base.Broadcast: broadcasted
415-
struct KroneckerBroadcasted{A<:Broadcasted,B<:Broadcasted}
434+
struct KroneckerBroadcasted{A,B}
416435
a::A
417436
b::B
418437
end
419438
arg1(a::KroneckerBroadcasted) = a.a
420439
arg2(a::KroneckerBroadcasted) = a.b
421440
(a::Broadcasted, b::Broadcasted) = KroneckerBroadcasted(a, b)
441+
(a::Broadcasted, b) = KroneckerBroadcasted(a, b)
442+
(a, b::Broadcasted) = KroneckerBroadcasted(a, b)
422443
Broadcast.materialize(a::KroneckerBroadcasted) = copy(a)
423444
Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a)
424445
Broadcast.broadcastable(a::KroneckerBroadcasted) = a

src/linearalgebra.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,3 @@ function LinearAlgebra.lq(a::KroneckerArray)
179179
Fb = lq(a.b)
180180
return KroneckerLQ(Fa.L Fb.L, Fa.Q Fb.Q)
181181
end
182-
183-
using DerivableInterfaces: DerivableInterfaces, zero!
184-
function DerivableInterfaces.zero!(a::KroneckerArray)
185-
zero!(a.a)
186-
zero!(a.b)
187-
return a
188-
end

test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1414
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1515
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1616
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
17+
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
1718
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1819
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1920

2021
[compat]
2122
Adapt = "4"
2223
Aqua = "0.8"
2324
BlockArrays = "1.6"
24-
BlockSparseArrays = "0.8.1"
25+
BlockSparseArrays = "0.9"
2526
DerivableInterfaces = "0.5"
2627
DiagonalArrays = "0.3.7"
2728
FillArrays = "1"
@@ -33,5 +34,6 @@ MatrixAlgebraKit = "0.2"
3334
SafeTestsets = "0.1"
3435
StableRNGs = "1.0"
3536
Suppressor = "0.2"
37+
TensorProducts = "0.1.7"
3638
Test = "1.10"
3739
TestExtras = "0.3"

test/test_basics.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using KroneckerArrays:
99
KroneckerArray,
1010
KroneckerStyle,
1111
CartesianProductUnitRange,
12+
CartesianProductVector,
1213
,
1314
×,
1415
arg1,
@@ -45,6 +46,14 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
4546
@test r[2 × 2] == 5
4647
@test r[2 × 3] == 6
4748

49+
# CartesianProductUnitRange axes
50+
r = cartesianrange((2:3) × (3:4), 2:5)
51+
@test axes(r) (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),)
52+
53+
# CartesianProductVector axes
54+
r = CartesianProductVector(([2, 4]) × ([3, 5]), [3, 5, 7, 9])
55+
@test axes(r) (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),)
56+
4857
r = @constinferred(cartesianrange(2 × 3, 2:7))
4958
@test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7)
5059
@test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3)

test/test_blocksparsearrays.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ arrayts = (Array, JLArray)
2323
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
2424
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
2525
)
26-
a = dev(blocksparse(d, r, r))
26+
a = dev(blocksparse(d, (r, r)))
2727
@test sprint(show, a) isa String
2828
@test sprint(show, MIME("text/plain"), a) isa String
2929
@test blocktype(a) === valtype(d)
@@ -45,7 +45,7 @@ arrayts = (Array, JLArray)
4545
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
4646
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
4747
)
48-
a = dev(blocksparse(d, r, r))
48+
a = dev(blocksparse(d, (r, r)))
4949
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
5050
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
5151
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
@@ -68,7 +68,7 @@ arrayts = (Array, JLArray)
6868
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
6969
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
7070
)
71-
a = dev(blocksparse(d, r, r))
71+
a = dev(blocksparse(d, (r, r)))
7272
i1 = Block(1)[(1:2) × (1:2)]
7373
i2 = Block(2)[(2:3) × (2:3)]
7474
I = mortar([i1, i2])
@@ -83,7 +83,7 @@ arrayts = (Array, JLArray)
8383
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
8484
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
8585
)
86-
a = dev(blocksparse(d, r, r))
86+
a = dev(blocksparse(d, (r, r)))
8787
i1 = Block(1)[(1:2) × (1:2)]
8888
i2 = Block(2)[(2:3) × (2:3)]
8989
I = [i1, i2]
@@ -130,9 +130,12 @@ arrayts = (Array, JLArray)
130130
@test_broken svd_compact(a)
131131
end
132132

133+
b = a[Block.(1:2), Block(2)]
134+
@test b[Block(1)] == a[Block(1, 2)]
135+
@test b[Block(2)] == a[Block(2, 2)]
136+
133137
# Broken operations
134138
@test_broken exp(a)
135-
@test_broken a[Block.(1:2), Block(2)]
136139
end
137140

138141
@testset "BlockSparseArraysExt, EyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in
@@ -145,7 +148,7 @@ end
145148
Block(1, 1) => Eye{elt}(2, 2) dev(randn(elt, 2, 2)),
146149
Block(2, 2) => Eye{elt}(3, 3) dev(randn(elt, 3, 3)),
147150
)
148-
a = @constinferred dev(blocksparse(d, r, r))
151+
a = @constinferred dev(blocksparse(d, (r, r)))
149152
@test sprint(show, a) == sprint(show, Array(a))
150153
@test sprint(show, MIME("text/plain"), a) isa String
151154
@test @constinferred(blocktype(a)) === valtype(d)
@@ -167,7 +170,7 @@ end
167170
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
168171
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
169172
)
170-
a = dev(blocksparse(d, r, r))
173+
a = dev(blocksparse(d, (r, r)))
171174
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
172175
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
173176
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
@@ -194,7 +197,7 @@ end
194197
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
195198
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
196199
)
197-
a = dev(blocksparse(d, r, r))
200+
a = dev(blocksparse(d, (r, r)))
198201
i1 = Block(1)[(1:2) × (1:2)]
199202
i2 = Block(2)[(2:3) × (2:3)]
200203
I = mortar([i1, i2])
@@ -209,7 +212,7 @@ end
209212
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
210213
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
211214
)
212-
a = dev(blocksparse(d, r, r))
215+
a = dev(blocksparse(d, (r, r)))
213216
i1 = Block(1)[(1:2) × (1:2)]
214217
i2 = Block(2)[(2:3) × (2:3)]
215218
I = [i1, i2]
@@ -272,7 +275,9 @@ end
272275
end
273276

274277
# Broken operations
275-
@test_broken a[Block.(1:2), Block(2)]
278+
b = a[Block.(1:2), Block(2)]
279+
@test b[Block(1)] == a[Block(1, 2)]
280+
@test b[Block(2)] == a[Block(2, 2)]
276281

277282
# svd_trunc
278283
dev = adapt(arrayt)
@@ -282,7 +287,7 @@ end
282287
Block(1, 1) => Eye{elt}(2, 2) randn(rng, elt, 2, 2),
283288
Block(2, 2) => Eye{elt}(3, 3) randn(rng, elt, 3, 3),
284289
)
285-
a = @constinferred dev(blocksparse(d, r, r))
290+
a = @constinferred dev(blocksparse(d, (r, r)))
286291
if arrayt === Array
287292
u, s, v = svd_trunc(a; trunc=(; maxrank=6))
288293
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=5))
@@ -293,10 +298,10 @@ end
293298

294299
@testset "Block deficient" begin
295300
da = Dict(Block(1, 1) => Eye{elt}(2, 2) dev(randn(elt, 2, 2)))
296-
a = @constinferred dev(blocksparse(da, r, r))
301+
a = @constinferred dev(blocksparse(da, (r, r)))
297302

298303
db = Dict(Block(2, 2) => Eye{elt}(3, 3) dev(randn(elt, 3, 3)))
299-
b = @constinferred dev(blocksparse(db, r, r))
304+
b = @constinferred dev(blocksparse(db, (r, r)))
300305

301306
@test Array(a + b) Array(a) + Array(b)
302307
@test Array(2a) 2Array(a)

0 commit comments

Comments
 (0)