Skip to content

Commit adbd300

Browse files
authored
Interface improvements (#56)
1 parent bbf9bff commit adbd300

20 files changed

+1054
-1112
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.9"
4+
version = "0.3.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
66
[compat]
77
Documenter = "1"
88
Literate = "2"
9-
KroneckerArrays = "0.2"
9+
KroneckerArrays = "0.3"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33

44
[compat]
5-
KroneckerArrays = "0.2"
5+
KroneckerArrays = "0.3"
Lines changed: 30 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,51 @@
11
module KroneckerArraysBlockSparseArraysExt
22

3-
using BlockArrays: Block
4-
using BlockSparseArrays: BlockIndexVector, GenericBlockIndex
5-
using KroneckerArrays: CartesianPair, CartesianProduct
6-
function Base.getindex(
7-
b::Block{N},
8-
I::Vararg{Union{CartesianPair, CartesianProduct}, N}
9-
) where {N}
10-
return GenericBlockIndex(b, I)
11-
end
12-
function Base.getindex(b::Block{N}, I::Vararg{CartesianProduct, N}) where {N}
13-
return BlockIndexVector(b, I)
14-
end
3+
using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerVector,
4+
CartesianPair, CartesianProduct, CartesianProductUnitRange,
5+
kroneckerfactors, , isactive, cartesianrange
6+
using BlockArrays: BlockArrays, Block, AbstractBlockedUnitRange, mortar
7+
using BlockSparseArrays: BlockSparseArrays, BlockIndexVector, GenericBlockIndex, ZeroBlocks,
8+
blockrange, eachblockaxis, mortar_axis
9+
using DiagonalArrays: ShapeInitializer
1510

16-
using BlockSparseArrays: BlockSparseArrays, blockrange
17-
using KroneckerArrays: CartesianPair, CartesianProduct, cartesianrange
18-
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair})
19-
return blockrange(map(cartesianrange, bs))
20-
end
21-
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
22-
return blockrange(map(cartesianrange, bs))
23-
end
2411

25-
using BlockArrays: BlockArrays, mortar
26-
using BlockSparseArrays: blockrange
27-
using KroneckerArrays: CartesianProductUnitRange
12+
Base.getindex(b::Block{N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N}) where {N} =
13+
GenericBlockIndex(b, I)
14+
Base.getindex(b::Block{N}, I::Vararg{CartesianProduct, N}) where {N} =
15+
BlockIndexVector(b, I)
16+
17+
BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair}) = blockrange(map(cartesianrange, bs))
18+
BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct}) = blockrange(map(cartesianrange, bs))
19+
2820
# Makes sure that `mortar` results in a `BlockVector` with the correct
2921
# axes, otherwise the axes would not preserve the Kronecker structure.
3022
# This is helpful when indexing `BlockUnitRange`, for example:
3123
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.7.1/src/blockaxis.jl#L540-L547
32-
function BlockArrays.mortar(blocks::AbstractVector{<:CartesianProductUnitRange})
33-
return mortar(blocks, (blockrange(map(Base.axes1, blocks)),))
34-
end
24+
BlockArrays.mortar(blocks::AbstractVector{<:CartesianProductUnitRange}) =
25+
mortar(blocks, (blockrange(map(Base.axes1, blocks)),))
3526

36-
using BlockArrays: AbstractBlockedUnitRange
37-
using BlockSparseArrays: Block, ZeroBlocks, eachblockaxis, mortar_axis
38-
using KroneckerArrays: KroneckerArrays, KroneckerArray, , arg1, arg2, isactive
3927

40-
function KroneckerArrays.arg1(r::AbstractBlockedUnitRange)
41-
return mortar_axis(arg1.(eachblockaxis(r)))
42-
end
43-
function KroneckerArrays.arg2(r::AbstractBlockedUnitRange)
44-
return mortar_axis(arg2.(eachblockaxis(r)))
45-
end
28+
KroneckerArrays.kroneckerfactors(r::AbstractBlockedUnitRange, i::Int) =
29+
mortar_axis(kroneckerfactors.(eachblockaxis(r), i))
30+
KroneckerArrays.kroneckerfactors(r::AbstractBlockedUnitRange) =
31+
(kroneckerfactors(r, 1), kroneckerfactors(r, 2))
4632

47-
function block_axes(
48-
ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Vararg{Block{1}, N}
49-
) where {N}
33+
function block_axes(ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Vararg{Block{1}, N}) where {N}
5034
return ntuple(N) do d
5135
return only(axes(ax[d][I[d]]))
5236
end
5337
end
54-
function block_axes(ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Block{N}) where {N}
55-
return block_axes(ax, Tuple(I)...)
56-
end
57-
58-
using DiagonalArrays: ShapeInitializer
38+
block_axes(ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Block{N}) where {N} =
39+
block_axes(ax, Tuple(I)...)
5940

6041
## TODO: Is this needed?
6142
function Base.getindex(
6243
a::ZeroBlocks{N, KroneckerArray{T, N, A1, A2}}, I::Vararg{Int, N}
6344
) where {T, N, A1 <: AbstractArray{T, N}, A2 <: AbstractArray{T, N}}
64-
ax_a1 = map(arg1, a.parentaxes)
65-
ax_a2 = map(arg2, a.parentaxes)
66-
block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I)))
67-
block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I)))
45+
ax_a1 = kroneckerfactors.(a.parentaxes, 1)
46+
ax_a2 = kroneckerfactors.(a.parentaxes, 2)
47+
block_ax_a1 = kroneckerfactors.(block_axes(a.parentaxes, Block(I)), 1)
48+
block_ax_a2 = kroneckerfactors.(block_axes(a.parentaxes, Block(I)), 2)
6849
# TODO: Is this a good definition? It is similar to
6950
# the definition of `similar` and `adapt_structure`.
7051
return if isactive(A1) == isactive(A2)
@@ -76,10 +57,7 @@ function Base.getindex(
7657
end
7758
end
7859

79-
using BlockSparseArrays: BlockSparseArrays
80-
using KroneckerArrays: KroneckerArrays, KroneckerVector
81-
function BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I)
82-
return KroneckerArrays.to_truncated_indices(values, I)
83-
end
60+
BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I) =
61+
KroneckerArrays.to_truncated_indices(values, I)
8462

8563
end

ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
module KroneckerArraysTensorAlgebraExt
22

3-
using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, , arg1, arg2
4-
using TensorAlgebra:
5-
TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize
3+
using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, , kroneckerfactors
4+
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, FusionStyle,
5+
matricize, unmatricize
66

77
struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle
88
a::A
99
b::B
1010
end
11-
KroneckerArrays.arg1(style::KroneckerFusion) = style.a
12-
KroneckerArrays.arg2(style::KroneckerFusion) = style.b
13-
function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray)
14-
return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a)))
15-
end
11+
KroneckerArrays.kroneckerfactors(style::KroneckerFusion) = (style.a, style.b)
12+
KroneckerArrays.kroneckerfactortypes(::Type{KroneckerFusion{A, B}}) where {A, B} = (A, B)
13+
14+
TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) = KroneckerFusion(FusionStyle.(kroneckerfactors(a))...)
1615
function matricize_kronecker(
1716
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
1817
)
19-
return matricize(arg1(style), arg1(a), biperm) matricize(arg2(style), arg2(a), biperm)
18+
return matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), biperm)
19+
matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), biperm)
2020
end
2121
function TensorAlgebra.matricize(
2222
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
@@ -32,8 +32,8 @@ function TensorAlgebra.matricize(
3232
return matricize_kronecker(style, a, biperm)
3333
end
3434
function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax)
35-
return unmatricize(arg1(style), arg1(a), arg1.(ax))
36-
unmatricize(arg2(style), arg2(a), arg2.(ax))
35+
return unmatricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), kroneckerfactors.(ax, 1))
36+
unmatricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), kroneckerfactors.(ax, 2))
3737
end
3838
function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax)
3939
return unmatricize_kronecker(style, a, ax)
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
module KroneckerArraysTensorProductsExt
22

3-
using KroneckerArrays: CartesianProductOneTo, ×, arg1, arg2, cartesianrange, unproduct
43
using TensorProducts: TensorProducts, tensor_product
4+
using KroneckerArrays: CartesianProductOneTo, kroneckerfactors, cartesianrange, unproduct
5+
56
function TensorProducts.tensor_product(a1::CartesianProductOneTo, a2::CartesianProductOneTo)
6-
prod = tensor_product(arg1(a1), arg1(a2)) × tensor_product(arg2(a1), arg2(a2))
7-
range = tensor_product(unproduct(a1), unproduct(a2))
8-
return cartesianrange(prod, range)
7+
return cartesianrange(
8+
tensor_product(kroneckerfactors(a1, 1), kroneckerfactors(a2, 1)),
9+
tensor_product(kroneckerfactors(a1, 2), kroneckerfactors(a2, 2)),
10+
tensor_product(unproduct(a1), unproduct(a2))
11+
)
912
end
1013

1114
end

src/KroneckerArrays.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,52 @@
11
module KroneckerArrays
22

3+
export kroneckerfactors, kroneckerfactortypes
4+
export times, ×, cartesianproduct, cartesianrange, unproduct
35
export , ×
46

7+
# Imports
8+
# -------
9+
import Base.Broadcast as BC
10+
using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag
11+
using DiagonalArrays: DiagonalArrays
12+
using DerivableInterfaces: DerivableInterfaces
13+
using MapBroadcast: MapBroadcast, MapFunction, LinearCombination, Summed
14+
using GPUArraysCore: GPUArraysCore
15+
using Adapt: Adapt
16+
17+
# Interfaces
18+
# ----------
19+
@doc """
20+
kroneckerfactors(x) -> Tuple
21+
kroneckerfactors(x, i) = kroneckerfactors(x)[i]
22+
23+
Extract the factors of `x`, where `x` is an object that represents a lazily composed product type.
24+
""" kroneckerfactors
25+
# note: this is `Int` instead of `Integer` to avoid ambiguities downstream
26+
@inline kroneckerfactors(x, i::Int) = kroneckerfactors(x)[i]
27+
28+
@doc """
29+
kroneckerfactortypes(x) -> Tuple
30+
kroneckerfactortypes(x, i) = kroneckerfactortypes(x)[i]
31+
32+
Extract the types of the factors of `x`, where `x` is an object or type that represents a lazily composed product type.
33+
""" kroneckerfactortypes
34+
# note: this is `Int` instead of `Integer` to avoid ambiguities downstream
35+
@inline kroneckerfactortypes(x, i::Int) = kroneckerfactortypes(x)[i]
36+
kroneckerfactortypes(x) = kroneckerfactortypes(typeof(x))
37+
kroneckerfactortypes(T::Type) = throw(MethodError(kroneckerfactortypes, (T,)))
38+
39+
@doc """
40+
⊗(args...)
41+
otimes(args...)
42+
43+
Construct an object that represents the Kronecker product of the provided `args`.
44+
""" ()
45+
function (a, b) end
46+
const otimes = # non-unicode alternative
47+
48+
# Includes
49+
# --------
550
include("cartesianproduct.jl")
651
include("kroneckerarray.jl")
752
include("linearalgebra.jl")

0 commit comments

Comments
 (0)