Skip to content

Commit 7bf3e18

Browse files
authored
Rewrite MatrixAlgebraKit factorization definitions (#43)
1 parent 7433652 commit 7bf3e18

22 files changed

+1144
-1495
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.31"
4+
version = "0.2.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -12,6 +12,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1414
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
15+
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1516

1617
[weakdeps]
1718
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -29,12 +30,13 @@ Adapt = "4.3"
2930
BlockArrays = "1.6"
3031
BlockSparseArrays = "0.9, 0.10.3"
3132
DerivableInterfaces = "0.5.3"
32-
DiagonalArrays = "0.3.11"
33+
DiagonalArrays = "0.3.19"
3334
FillArrays = "1.13"
3435
GPUArraysCore = "0.2"
3536
LinearAlgebra = "1.10"
3637
MapBroadcast = "0.1.10"
3738
MatrixAlgebraKit = "0.2, 0.3"
3839
TensorAlgebra = "0.3.10"
3940
TensorProducts = "0.1.7"
41+
TypeParameterAccessors = "0.4.2"
4042
julia = "1.10"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
66
[compat]
77
Documenter = "1"
88
Literate = "2"
9-
KroneckerArrays = "0.1"
9+
KroneckerArrays = "0.2"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33

44
[compat]
5-
KroneckerArrays = "0.1"
5+
KroneckerArrays = "0.2"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ end
3636

3737
using BlockArrays: AbstractBlockedUnitRange
3838
using BlockSparseArrays: Block, ZeroBlocks, eachblockaxis, mortar_axis
39-
using KroneckerArrays: KroneckerArrays, KroneckerArray, , arg1, arg2, _similar
40-
using BlockSparseArrays.TypeParameterAccessors: unwrap_array_type
39+
using KroneckerArrays: KroneckerArrays, KroneckerArray, , arg1, arg2, isactive
4140

4241
function KroneckerArrays.arg1(r::AbstractBlockedUnitRange)
4342
return mortar_axis(arg1.(eachblockaxis(r)))
@@ -57,30 +56,25 @@ function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) whe
5756
return block_axes(ax, Tuple(I)...)
5857
end
5958

59+
using DiagonalArrays: ShapeInitializer
60+
6061
## TODO: Is this needed?
6162
function Base.getindex(
62-
a::ZeroBlocks{N,KroneckerArray{T,N,A,B}}, I::Vararg{Int,N}
63-
) where {T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}}
63+
a::ZeroBlocks{N,KroneckerArray{T,N,A1,A2}}, I::Vararg{Int,N}
64+
) where {T,N,A1<:AbstractArray{T,N},A2<:AbstractArray{T,N}}
6465
ax_a1 = map(arg1, a.parentaxes)
6566
ax_a2 = map(arg2, a.parentaxes)
66-
# TODO: Instead of mutability, maybe have a trait like
67-
# `isstructural` or `isdata`.
68-
ismut1 = ismutabletype(unwrap_array_type(A))
69-
ismut2 = ismutabletype(unwrap_array_type(B))
70-
(ismut1 || ismut2) || error("Can't get zero block.")
71-
a1 = if ismut1
72-
ZeroBlocks{N,A}(ax_a1)[I...]
73-
else
74-
block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I)))
75-
_similar(A, block_ax_a1)
76-
end
77-
a2 = if ismut2
78-
ZeroBlocks{N,B}(ax_a2)[I...]
79-
else
80-
block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I)))
81-
a2 = _similar(B, block_ax_a2)
67+
block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I)))
68+
block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I)))
69+
# TODO: Is this a good definition? It is similar to
70+
# the definition of `similar` and `adapt_structure`.
71+
return if isactive(A1) == isactive(A2)
72+
ZeroBlocks{N,A1}(ax_a1)[I...] ZeroBlocks{N,A2}(ax_a2)[I...]
73+
elseif isactive(A1)
74+
ZeroBlocks{N,A1}(ax_a1)[I...] A2(ShapeInitializer(), block_ax_a2)
75+
elseif isactive(A2)
76+
A1(ShapeInitializer(), block_ax_a1) ZeroBlocks{N,A2}(ax_a2)[I...]
8277
end
83-
return a1 a2
8478
end
8579

8680
using BlockSparseArrays: BlockSparseArrays

src/KroneckerArrays.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@ include("cartesianproduct.jl")
66
include("kroneckerarray.jl")
77
include("linearalgebra.jl")
88
include("matrixalgebrakit.jl")
9-
include("fillarrays/kroneckerarray.jl")
10-
include("fillarrays/linearalgebra.jl")
11-
include("fillarrays/matrixalgebrakit.jl")
12-
include("fillarrays/matrixalgebrakit_truncate.jl")
9+
include("fillarrays.jl")
1310

1411
end

src/cartesianproduct.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
struct CartesianPair{A,B}
2-
a::A
3-
b::B
1+
struct CartesianPair{A1,A2}
2+
arg1::A1
3+
arg2::A2
44
end
5-
arguments(a::CartesianPair) = (a.a, a.b)
5+
arguments(a::CartesianPair) = (arg1(a), arg2(a))
66
arguments(a::CartesianPair, n::Int) = arguments(a)[n]
77

8-
arg1(a::CartesianPair) = a.a
9-
arg2(a::CartesianPair) = a.b
8+
arg1(a::CartesianPair) = getfield(a, :arg1)
9+
arg2(a::CartesianPair) = getfield(a, :arg2)
1010

11-
×(a, b) = CartesianPair(a, b)
11+
×(a1, a2) = CartesianPair(a1, a2)
1212

1313
function Base.show(io::IO, a::CartesianPair)
14-
print(io, a.a, " × ", a.b)
14+
print(io, arg1(a), " × ", arg2(a))
1515
return nothing
1616
end
1717

@@ -20,25 +20,25 @@ struct CartesianProduct{TA,TB,A<:AbstractVector{TA},B<:AbstractVector{TB}} <:
2020
a::A
2121
b::B
2222
end
23-
arguments(a::CartesianProduct) = (a.a, a.b)
23+
arguments(a::CartesianProduct) = (arg1(a), arg2(a))
2424
arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
2525

26-
arg1(a::CartesianProduct) = a.a
27-
arg2(a::CartesianProduct) = a.b
26+
arg1(a::CartesianProduct) = getfield(a, :a)
27+
arg2(a::CartesianProduct) = getfield(a, :b)
2828

2929
Base.copy(a::CartesianProduct) = copy(arg1(a)) × copy(arg2(a))
3030

3131
function Base.show(io::IO, a::CartesianProduct)
32-
print(io, a.a, " × ", a.b)
32+
print(io, arg1(a), " × ", arg2(a))
3333
return nothing
3434
end
3535
function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct)
3636
show(io, a)
3737
return nothing
3838
end
3939

40-
×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
41-
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
40+
×(a1::AbstractVector, a2::AbstractVector) = CartesianProduct(a1, a2)
41+
Base.length(a::CartesianProduct) = length(arg1(a)) * length(arg2(a))
4242
Base.size(a::CartesianProduct) = (length(a),)
4343

4444
function Base.getindex(a::CartesianProduct, i::CartesianProduct)
@@ -118,12 +118,12 @@ end
118118
function CartesianProductUnitRange(p::CartesianProduct)
119119
return CartesianProductUnitRange(p, Base.OneTo(length(p)))
120120
end
121-
function CartesianProductUnitRange(a, b)
122-
return CartesianProductUnitRange(a × b)
121+
function CartesianProductUnitRange(a1, a2)
122+
return CartesianProductUnitRange(a1 × a2)
123123
end
124124
to_product_indices(a::AbstractVector) = a
125125
to_product_indices(i::Integer) = Base.OneTo(i)
126-
cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b))
126+
cartesianrange(a1, a2) = cartesianrange(to_product_indices(a1) × to_product_indices(a2))
127127
function cartesianrange(p::CartesianPair)
128128
p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p))
129129
return cartesianrange(p′)

src/fillarrays.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
using FillArrays: FillArrays, Ones, Zeros
2+
function FillArrays.fillsimilar(
3+
a::Zeros{T},
4+
ax::Tuple{
5+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
6+
},
7+
) where {T}
8+
return Zeros{T}(arg1.(ax)) Zeros{T}(arg2.(ax))
9+
end
10+
11+
# Simplification rules similar to those for FillArrays.jl:
12+
# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl
13+
using FillArrays: Zeros
14+
function Base.broadcasted(
15+
style::KroneckerStyle,
16+
::typeof(+),
17+
a::KroneckerArray,
18+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
19+
)
20+
# TODO: Promote the element types.
21+
return a
22+
end
23+
function Base.broadcasted(
24+
style::KroneckerStyle,
25+
::typeof(+),
26+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
27+
b::KroneckerArray,
28+
)
29+
# TODO: Promote the element types.
30+
return b
31+
end
32+
function Base.broadcasted(
33+
style::KroneckerStyle,
34+
::typeof(+),
35+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
36+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
37+
)
38+
# TODO: Promote the element types and axes.
39+
return b
40+
end
41+
function Base.broadcasted(
42+
style::KroneckerStyle,
43+
::typeof(-),
44+
a::KroneckerArray,
45+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
46+
)
47+
# TODO: Promote the element types.
48+
return a
49+
end
50+
function Base.broadcasted(
51+
style::KroneckerStyle,
52+
::typeof(-),
53+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
54+
b::KroneckerArray,
55+
)
56+
# TODO: Promote the element types.
57+
# TODO: Return `broadcasted(-, b)`.
58+
return -b
59+
end
60+
function Base.broadcasted(
61+
style::KroneckerStyle,
62+
::typeof(-),
63+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
64+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
65+
)
66+
# TODO: Promote the element types and axes.
67+
return b
68+
end

0 commit comments

Comments
 (0)