Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.9"
version = "0.3.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
[compat]
Documenter = "1"
Literate = "2"
KroneckerArrays = "0.2"
KroneckerArrays = "0.3"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"

[compat]
KroneckerArrays = "0.2"
KroneckerArrays = "0.3"
Original file line number Diff line number Diff line change
@@ -1,70 +1,51 @@
module KroneckerArraysBlockSparseArraysExt

using BlockArrays: Block
using BlockSparseArrays: BlockIndexVector, GenericBlockIndex
using KroneckerArrays: CartesianPair, CartesianProduct
function Base.getindex(
b::Block{N},
I::Vararg{Union{CartesianPair, CartesianProduct}, N}
) where {N}
return GenericBlockIndex(b, I)
end
function Base.getindex(b::Block{N}, I::Vararg{CartesianProduct, N}) where {N}
return BlockIndexVector(b, I)
end
using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerVector,
CartesianPair, CartesianProduct, CartesianProductUnitRange,
kroneckerfactors, ⊗, isactive, cartesianrange
using BlockArrays: BlockArrays, Block, AbstractBlockedUnitRange, mortar
using BlockSparseArrays: BlockSparseArrays, BlockIndexVector, GenericBlockIndex, ZeroBlocks,
blockrange, eachblockaxis, mortar_axis
using DiagonalArrays: ShapeInitializer

using BlockSparseArrays: BlockSparseArrays, blockrange
using KroneckerArrays: CartesianPair, CartesianProduct, cartesianrange
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair})
return blockrange(map(cartesianrange, bs))
end
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
return blockrange(map(cartesianrange, bs))
end

using BlockArrays: BlockArrays, mortar
using BlockSparseArrays: blockrange
using KroneckerArrays: CartesianProductUnitRange
Base.getindex(b::Block{N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N}) where {N} =
GenericBlockIndex(b, I)
Base.getindex(b::Block{N}, I::Vararg{CartesianProduct, N}) where {N} =
BlockIndexVector(b, I)

BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair}) = blockrange(map(cartesianrange, bs))
BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct}) = blockrange(map(cartesianrange, bs))

# Makes sure that `mortar` results in a `BlockVector` with the correct
# axes, otherwise the axes would not preserve the Kronecker structure.
# This is helpful when indexing `BlockUnitRange`, for example:
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.7.1/src/blockaxis.jl#L540-L547
function BlockArrays.mortar(blocks::AbstractVector{<:CartesianProductUnitRange})
return mortar(blocks, (blockrange(map(Base.axes1, blocks)),))
end
BlockArrays.mortar(blocks::AbstractVector{<:CartesianProductUnitRange}) =
mortar(blocks, (blockrange(map(Base.axes1, blocks)),))

using BlockArrays: AbstractBlockedUnitRange
using BlockSparseArrays: Block, ZeroBlocks, eachblockaxis, mortar_axis
using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, arg1, arg2, isactive

function KroneckerArrays.arg1(r::AbstractBlockedUnitRange)
return mortar_axis(arg1.(eachblockaxis(r)))
end
function KroneckerArrays.arg2(r::AbstractBlockedUnitRange)
return mortar_axis(arg2.(eachblockaxis(r)))
end
KroneckerArrays.kroneckerfactors(r::AbstractBlockedUnitRange, i::Int) =
mortar_axis(kroneckerfactors.(eachblockaxis(r), i))
KroneckerArrays.kroneckerfactors(r::AbstractBlockedUnitRange) =
(kroneckerfactors(r, 1), kroneckerfactors(r, 2))

function block_axes(
ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Vararg{Block{1}, N}
) where {N}
function block_axes(ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Vararg{Block{1}, N}) where {N}
return ntuple(N) do d
return only(axes(ax[d][I[d]]))
end
end
function block_axes(ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Block{N}) where {N}
return block_axes(ax, Tuple(I)...)
end

using DiagonalArrays: ShapeInitializer
block_axes(ax::NTuple{N, AbstractUnitRange{<:Integer}}, I::Block{N}) where {N} =
block_axes(ax, Tuple(I)...)

## TODO: Is this needed?
function Base.getindex(
a::ZeroBlocks{N, KroneckerArray{T, N, A1, A2}}, I::Vararg{Int, N}
) where {T, N, A1 <: AbstractArray{T, N}, A2 <: AbstractArray{T, N}}
ax_a1 = map(arg1, a.parentaxes)
ax_a2 = map(arg2, a.parentaxes)
block_ax_a1 = arg1.(block_axes(a.parentaxes, Block(I)))
block_ax_a2 = arg2.(block_axes(a.parentaxes, Block(I)))
ax_a1 = kroneckerfactors.(a.parentaxes, 1)
ax_a2 = kroneckerfactors.(a.parentaxes, 2)
block_ax_a1 = kroneckerfactors.(block_axes(a.parentaxes, Block(I)), 1)
block_ax_a2 = kroneckerfactors.(block_axes(a.parentaxes, Block(I)), 2)
# TODO: Is this a good definition? It is similar to
# the definition of `similar` and `adapt_structure`.
return if isactive(A1) == isactive(A2)
Expand All @@ -76,10 +57,7 @@ function Base.getindex(
end
end

using BlockSparseArrays: BlockSparseArrays
using KroneckerArrays: KroneckerArrays, KroneckerVector
function BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I)
return KroneckerArrays.to_truncated_indices(values, I)
end
BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I) =
KroneckerArrays.to_truncated_indices(values, I)

end
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
module KroneckerArraysTensorAlgebraExt

using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, ⊗, arg1, arg2
using TensorAlgebra:
TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize
using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, ⊗, kroneckerfactors
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, FusionStyle,
matricize, unmatricize

struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle
a::A
b::B
end
KroneckerArrays.arg1(style::KroneckerFusion) = style.a
KroneckerArrays.arg2(style::KroneckerFusion) = style.b
function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray)
return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a)))
end
KroneckerArrays.kroneckerfactors(style::KroneckerFusion) = (style.a, style.b)
KroneckerArrays.kroneckerfactortypes(::Type{KroneckerFusion{A, B}}) where {A, B} = (A, B)

TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) = KroneckerFusion(FusionStyle.(kroneckerfactors(a))...)
function matricize_kronecker(
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
return matricize(arg1(style), arg1(a), biperm) ⊗ matricize(arg2(style), arg2(a), biperm)
return matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), biperm) ⊗
matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), biperm)
end
function TensorAlgebra.matricize(
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
Expand All @@ -32,8 +32,8 @@ function TensorAlgebra.matricize(
return matricize_kronecker(style, a, biperm)
end
function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax)
return unmatricize(arg1(style), arg1(a), arg1.(ax)) ⊗
unmatricize(arg2(style), arg2(a), arg2.(ax))
return unmatricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), kroneckerfactors.(ax, 1)) ⊗
unmatricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), kroneckerfactors.(ax, 2))
end
function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax)
return unmatricize_kronecker(style, a, ax)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
module KroneckerArraysTensorProductsExt

using KroneckerArrays: CartesianProductOneTo, ×, arg1, arg2, cartesianrange, unproduct
using TensorProducts: TensorProducts, tensor_product
using KroneckerArrays: CartesianProductOneTo, kroneckerfactors, cartesianrange, unproduct

function TensorProducts.tensor_product(a1::CartesianProductOneTo, a2::CartesianProductOneTo)
prod = tensor_product(arg1(a1), arg1(a2)) × tensor_product(arg2(a1), arg2(a2))
range = tensor_product(unproduct(a1), unproduct(a2))
return cartesianrange(prod, range)
return cartesianrange(
tensor_product(kroneckerfactors(a1, 1), kroneckerfactors(a2, 1)),
tensor_product(kroneckerfactors(a1, 2), kroneckerfactors(a2, 2)),
tensor_product(unproduct(a1), unproduct(a2))
)
end

end
45 changes: 45 additions & 0 deletions src/KroneckerArrays.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,52 @@
module KroneckerArrays

export kroneckerfactors, kroneckerfactortypes
export times, ×, cartesianproduct, cartesianrange, unproduct
export ⊗, ×

# Imports
# -------
import Base.Broadcast as BC
using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag
using DiagonalArrays: DiagonalArrays
using DerivableInterfaces: DerivableInterfaces
using MapBroadcast: MapBroadcast, MapFunction, LinearCombination, Summed
using GPUArraysCore: GPUArraysCore
using Adapt: Adapt

# Interfaces
# ----------
@doc """
kroneckerfactors(x) -> Tuple
kroneckerfactors(x, i) = kroneckerfactors(x)[i]

Extract the factors of `x`, where `x` is an object that represents a lazily composed product type.
""" kroneckerfactors
# note: this is `Int` instead of `Integer` to avoid ambiguities downstream
@inline kroneckerfactors(x, i::Int) = kroneckerfactors(x)[i]

@doc """
kroneckerfactortypes(x) -> Tuple
kroneckerfactortypes(x, i) = kroneckerfactortypes(x)[i]

Extract the types of the factors of `x`, where `x` is an object or type that represents a lazily composed product type.
""" kroneckerfactortypes
# note: this is `Int` instead of `Integer` to avoid ambiguities downstream
@inline kroneckerfactortypes(x, i::Int) = kroneckerfactortypes(x)[i]
kroneckerfactortypes(x) = kroneckerfactortypes(typeof(x))
kroneckerfactortypes(T::Type) = throw(MethodError(kroneckerfactortypes, (T,)))

@doc """
⊗(args...)
otimes(args...)

Construct an object that represents the Kronecker product of the provided `args`.
""" (⊗)
function ⊗(a, b) end
const otimes = ⊗ # non-unicode alternative

# Includes
# --------
include("cartesianproduct.jl")
include("kroneckerarray.jl")
include("linearalgebra.jl")
Expand Down
Loading
Loading