diff --git a/Project.toml b/Project.toml index 087e67c..d0c000f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.3.4" +version = "0.3.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -34,6 +34,6 @@ GPUArraysCore = "0.2" LinearAlgebra = "1.10" MapBroadcast = "0.1.10" MatrixAlgebraKit = "0.6" -TensorAlgebra = "0.6.2" +TensorAlgebra = "0.6.3" TypeParameterAccessors = "0.4.2" julia = "1.10" diff --git a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl index 921ab96..b23009a 100644 --- a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl +++ b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl @@ -3,7 +3,7 @@ module KroneckerArraysTensorAlgebraExt using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, CartesianProductUnitRange, ⊗, cartesianrange, kroneckerfactors, kroneckerfactortypes using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, - FusionStyle, matricize, tensor_product_axis, unmatricize + FusionStyle, matricize, tensor_product_axis, trivial_axis, unmatricize struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle a::A @@ -19,14 +19,57 @@ function TensorAlgebra.FusionStyle(A::Type{<:CartesianProductUnitRange}) return KroneckerFusion(FusionStyle.(kroneckerfactortypes(A))...) end +function TensorAlgebra.trivial_axis( + style::KroneckerFusion, side::Val{:codomain}, a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return trivial_kronecker(style, side, a, axes_codomain, axes_domain) +end +function TensorAlgebra.trivial_axis( + style::KroneckerFusion, side::Val{:domain}, a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return trivial_kronecker(style, side, a, axes_codomain, axes_domain) +end +function trivial_kronecker( + style::FusionStyle, side::Val, a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + style_a, style_b = kroneckerfactors(style) + a_a, a_b = kroneckerfactors(a) + axes_codomain_a = kroneckerfactors.(axes_codomain, 1) + axes_codomain_b = kroneckerfactors.(axes_codomain, 2) + axes_domain_a = kroneckerfactors.(axes_domain, 1) + axes_domain_b = kroneckerfactors.(axes_domain, 2) + ra = trivial_axis(style_a, side, a_a, axes_codomain_a, axes_domain_a) + rb = trivial_axis(style_b, side, a_b, axes_codomain_b, axes_domain_b) + return cartesianrange(ra, rb) +end + function TensorAlgebra.tensor_product_axis( - style::KroneckerFusion, r1::AbstractUnitRange, r2::AbstractUnitRange + style::KroneckerFusion, side::Val{:codomain}, + r1::AbstractUnitRange, r2::AbstractUnitRange, + ) + return tensor_product_kronecker(style, side, r1, r2) +end +function TensorAlgebra.tensor_product_axis( + style::KroneckerFusion, side::Val{:domain}, + r1::AbstractUnitRange, r2::AbstractUnitRange, + ) + return tensor_product_kronecker(style, side, r1, r2) +end +function tensor_product_kronecker( + style::KroneckerFusion, side::Val, + r1::AbstractUnitRange, r2::AbstractUnitRange, ) style_a, style_b = kroneckerfactors(style) r1a, r1b = kroneckerfactors(r1) r2a, r2b = kroneckerfactors(r2) - ra = tensor_product_axis(style_a, r1a, r2a) - rb = tensor_product_axis(style_b, r1b, r2b) + ra = tensor_product_axis(style_a, side, r1a, r2a) + rb = tensor_product_axis(style_b, side, r1b, r2b) return cartesianrange(ra, rb) end @@ -44,8 +87,7 @@ function TensorAlgebra.matricize( end function unmatricize_kronecker( - style::FusionStyle, - m::AbstractMatrix, + style::FusionStyle, m::AbstractMatrix, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) @@ -60,8 +102,7 @@ function unmatricize_kronecker( return a1 ⊗ a2 end function TensorAlgebra.unmatricize( - style::KroneckerFusion, - m::AbstractMatrix, + style::KroneckerFusion, m::AbstractMatrix, codomain_axes::Tuple{Vararg{AbstractUnitRange}}, domain_axes::Tuple{Vararg{AbstractUnitRange}}, ) diff --git a/test/test_tensoralgebra.jl b/test/test_tensoralgebra.jl index ecbcd8e..fdb0efb 100644 --- a/test/test_tensoralgebra.jl +++ b/test/test_tensoralgebra.jl @@ -1,9 +1,12 @@ -using TensorAlgebra: matricize, tensor_product_axis, unmatricize +using TensorAlgebra: matricize, tensor_product_axis, trivial_axis, unmatricize using KroneckerArrays: ⊗, cartesianrange, kroneckerfactors, unproduct using Test: @test, @testset @testset "TensorAlgebraExt" begin @testset "tensor_product_axis" begin + r = cartesianrange(2, 3) + @test trivial_axis(r) ≡ cartesianrange(1, 1) + r1 = cartesianrange(2, 3) r2 = cartesianrange(4, 5) r = tensor_product_axis(r1, r2) @@ -15,7 +18,8 @@ using Test: @test, @testset @testset "matricize/unmatricize" begin a = randn(2, 2, 2) ⊗ randn(3, 3, 3) m = matricize(a, (1, 2), (3,)) - @test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,)) ⊗ matricize(kroneckerfactors(a, 2), (1, 2), (3,)) + @test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,)) ⊗ + matricize(kroneckerfactors(a, 2), (1, 2), (3,)) @test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a end end