diff --git a/Project.toml b/Project.toml index 3a2f50b..c13cd4e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.27" +version = "0.1.28" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -16,22 +16,25 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" [weakdeps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" [extensions] KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] +KroneckerArraysTensorAlgebraExt = "TensorAlgebra" KroneckerArraysTensorProductsExt = "TensorProducts" [compat] Adapt = "4.3" BlockArrays = "1.6" BlockSparseArrays = "0.9" -DerivableInterfaces = "0.5" -DiagonalArrays = "0.3.5" +DerivableInterfaces = "0.5.3" +DiagonalArrays = "0.3.11" FillArrays = "1.13" GPUArraysCore = "0.2" LinearAlgebra = "1.10" MapBroadcast = "0.1.9" MatrixAlgebraKit = "0.2" +TensorAlgebra = "0.3.10" TensorProducts = "0.1.7" julia = "1.10" diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index e67d3aa..624b269 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -39,7 +39,7 @@ using KroneckerArrays: _similar function KroneckerArrays.arg1(r::AbstractBlockedUnitRange) - return mortar_axis(arg2.(eachblockaxis(r))) + return mortar_axis(arg1.(eachblockaxis(r))) end function KroneckerArrays.arg2(r::AbstractBlockedUnitRange) return mortar_axis(arg2.(eachblockaxis(r))) @@ -56,15 +56,14 @@ function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) whe return block_axes(ax, Tuple(I)...) end +## TODO: Is this needed? function Base.getindex( a::ZeroBlocks{2,KroneckerMatrix{T,A,B}}, I::Vararg{Int,2} ) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} - ax_a1 = arg1.(a.parentaxes) + ax_a1 = map(arg1, a.parentaxes) a1 = ZeroBlocks{2,A}(ax_a1)[I...] - - ax_a2 = arg2.(a.parentaxes) + ax_a2 = map(arg2, a.parentaxes) a2 = ZeroBlocks{2,B}(ax_a2)[I...] - return a1 ⊗ a2 end function Base.getindex( diff --git a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl new file mode 100644 index 0000000..2969ea5 --- /dev/null +++ b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl @@ -0,0 +1,42 @@ +module KroneckerArraysTensorAlgebraExt + +using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, arg1, arg2 +using TensorAlgebra: + TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize + +struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle + a::A + b::B +end +KroneckerArrays.arg1(style::KroneckerFusion) = style.a +KroneckerArrays.arg2(style::KroneckerFusion) = style.b +function TensorAlgebra.FusionStyle(a::KroneckerArray) + return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a))) +end +function matricize_kronecker( + style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} +) + return matricize(arg1(style), arg1(a), biperm) ⊗ matricize(arg2(style), arg2(a), biperm) +end +function TensorAlgebra.matricize( + style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2} +) + return matricize_kronecker(style, a, biperm) +end +# Fix ambiguity error. +# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this. +using TensorAlgebra: BlockedTrivialPermutation, unmatricize +function TensorAlgebra.matricize( + style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2} +) + return matricize_kronecker(style, a, biperm) +end +function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax) + return unmatricize(arg1(style), arg1(a), arg1.(ax)) ⊗ + unmatricize(arg2(style), arg2(a), arg2.(ax)) +end +function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax) + return unmatricize_kronecker(style, a, ax) +end + +end diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index 06d54f3..c939dfa 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -98,6 +98,12 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range) arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a)) arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a)) +function Base.getindex(a::CartesianProductUnitRange, i::CartesianProductUnitRange) + prod = cartesianproduct(a)[cartesianproduct(i)] + range = unproduct(a)[unproduct(i)] + return cartesianrange(prod, range) +end + function Base.show(io::IO, a::CartesianProductUnitRange) show(io, unproduct(a)) return nothing diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 2298039..0faeb48 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -1,4 +1,4 @@ -using FillArrays: FillArrays, Zeros +using FillArrays: FillArrays, Ones, Zeros function FillArrays.fillsimilar( a::Zeros{T}, ax::Tuple{ @@ -21,6 +21,11 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} +using DiagonalArrays: Delta +const DeltaKronecker{T,N,A<:Delta{T,N},B<:AbstractArray{T,N}} = KroneckerArray{T,N,A,B} +const KroneckerDelta{T,N,A<:AbstractArray{T,N},B<:Delta{T,N}} = KroneckerArray{T,N,A,B} +const DeltaDelta{T,N,A<:Delta{T,N},B<:Delta{T,N}} = KroneckerArray{T,N,A,B} + _getindex(a::Eye, I1::Colon, I2::Colon) = a _getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a _getindex(a::Eye, I1::Base.Slice, I2::Colon) = a @@ -30,15 +35,23 @@ _view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a _view(a::Eye, I1::Base.Slice, I2::Colon) = a _view(a::Eye, I1::Colon, I2::Base.Slice) = a +function _getindex(a::Delta, I1::Union{Colon,Base.Slice}, Irest::Union{Colon,Base.Slice}...) + return a +end +function _view(a::Delta, I1::Union{Colon,Base.Slice}, Irest::Union{Colon,Base.Slice}...) + return a +end + # Like `adapt` but preserves `Eye`. _adapt(to, a::Eye) = a +_adapt(to, a::Delta) = a # Allows customizing for `FillArrays.Eye`. function _convert(::Type{AbstractArray{T}}, a::RectDiagonal) where {T} - _convert(AbstractMatrix{T}, a) + return _convert(AbstractMatrix{T}, a) end function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T} - RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a)) + return RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a)) end # Like `similar` but preserves `Eye`, `Ones`, etc. @@ -61,8 +74,33 @@ function _similar(arrayt::Type{<:SquareEye}, axs::NTuple{2,AbstractUnitRange}) return Eye{eltype(arrayt)}((only(unique(axs)),)) end -# Like `copy` but preserves `Eye`. +function _similar(a::Delta, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}}) + return Delta{elt}(axs) +end +function _similar(arrayt::Type{<:Delta}, axs::Tuple{Vararg{AbstractUnitRange}}) + return Delta{eltype(arrayt)}(axs) +end + +# Like `copy` but preserves `Eye`/`Delta`. _copy(a::Eye) = a +_copy(a::Delta) = a + +function _copyto!!(dest::Eye{<:Any,N}, src::Eye{<:Any,N}) where {N} + size(dest) == size(src) || + throw(ArgumentError("Sizes do not match: $(size(dest)) != $(size(src)).")) + return dest +end +function _copyto!!(dest::Delta{<:Any,N}, src::Delta{<:Any,N}) where {N} + size(dest) == size(src) || + throw(ArgumentError("Sizes do not match: $(size(dest)) != $(size(src)).")) + return dest +end + +function _permutedims!!(dest::Delta, src::Delta, perm) + Base.PermutedDimsArrays.genperm(axes(src), perm) == axes(dest) || + throw(ArgumentError("Permuted axes do not match.")) + return dest +end using Base.Broadcast: AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted @@ -75,10 +113,16 @@ end Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle() Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2 +function _copyto!!(dest::Eye, src::Broadcasted{<:EyeStyle,<:Any,typeof(identity)}) + axes(dest) == axes(src) || error("Dimension mismatch.") + return dest +end + function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type) return Eye{elt}(axes(bc)) end +# TODO: Define in terms of `_copyto!!` that is called on each argument. function Base.copyto!(dest::EyeKronecker, a::Sum{<:KroneckerStyle{<:Any,EyeStyle()}}) dest2 = arg2(dest) f = LinearCombination(a) @@ -99,6 +143,47 @@ function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),Eye return error("Can't write in-place to `Eye ⊗ Eye`.") end +struct DeltaStyle{N} <: AbstractArrayStyle{N} end +DeltaStyle(::Val{N}) where {N} = DeltaStyle{N}() +DeltaStyle{M}(::Val{N}) where {M,N} = DeltaStyle{N}() +function _BroadcastStyle(A::Type{<:Delta}) + return DeltaStyle{ndims(A)}() +end +Base.BroadcastStyle(style1::DeltaStyle, style2::DeltaStyle) = DeltaStyle() +Base.BroadcastStyle(style1::DeltaStyle, style2::DefaultArrayStyle) = style2 + +function _copyto!!(dest::Delta, src::Broadcasted{<:DeltaStyle,<:Any,typeof(identity)}) + axes(dest) == axes(src) || error("Dimension mismatch.") + return dest +end + +function Base.similar(bc::Broadcasted{<:DeltaStyle}, elt::Type) + return Delta{elt}(axes(bc)) +end + +# TODO: Dispatch on `DeltaStyle`. +function Base.copyto!(dest::DeltaKronecker, a::Sum{<:KroneckerStyle}) + dest2 = arg2(dest) + f = LinearCombination(a) + args = arguments(a) + arg2s = arg2.(args) + dest2 .= f.(arg2s...) + return dest +end +# TODO: Dispatch on `DeltaStyle`. +function Base.copyto!(dest::KroneckerDelta, a::Sum{<:KroneckerStyle}) + dest1 = arg1(dest) + f = LinearCombination(a) + args = arguments(a) + arg1s = arg1.(args) + dest1 .= f.(arg1s...) + return dest +end +# TODO: Dispatch on `DeltaStyle`. +function Base.copyto!(dest::DeltaDelta, a::Sum{<:KroneckerStyle}) + return error("Can't write in-place to `Delta ⊗ Delta`.") +end + # Simplification rules similar to those for FillArrays.jl: # https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl using FillArrays: Zeros diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 6b492ef..9d30a08 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -43,9 +43,26 @@ _copy(a::AbstractArray) = copy(a) function Base.copy(a::KroneckerArray) return _copy(arg1(a)) ⊗ _copy(arg2(a)) end -function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) - copyto!(arg1(dest), arg1(src)) - copyto!(arg2(dest), arg2(src)) + +# Allows extra customization, like for `FillArrays.Eye`. +function _copyto!!(dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) where {N} + copyto!(dest, src) + return dest +end +function _copyto!!(dest::AbstractArray, src::Broadcasted) + copyto!(dest, src) + return dest +end + +function Base.copyto!(dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}) where {N} + return copyto!_kronecker(dest, src) +end +function copyto!_kronecker( + dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N} +) where {N} + # TODO: Check if neither argument is mutated and if so error. + _copyto!!(arg1(dest), arg1(src)) + _copyto!!(arg2(dest), arg2(src)) return dest end @@ -110,6 +127,23 @@ function Base.similar( return similar(promote_type(A, B), sz) end +function _permutedims!!(dest::AbstractArray, src::AbstractArray, perm) + permutedims!(dest, src, perm) + return dest +end + +using DerivableInterfaces: DerivableInterfaces, permuteddims +function DerivableInterfaces.permuteddims(a::KroneckerArray, perm) + return permuteddims(arg1(a), perm) ⊗ permuteddims(arg2(a), perm) +end + +function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm) + # TODO: Error if neither argument is mutable. + _permutedims!!(arg1(dest), arg1(src), perm) + _permutedims!!(arg2(dest), arg2(src), perm) + return dest +end + function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}}) return (t[1]..., flatten(Base.tail(t))...) end @@ -128,7 +162,7 @@ function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N} a′ = reshape(a, interleave(size(a), ntuple(one, N))) b′ = reshape(b, interleave(ntuple(one, N), size(b))) c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N))) - sz = ntuple(i -> size(a, i) * size(b, i), N) + sz = reverse(ntuple(i -> size(a, i) * size(b, i), N)) return permutedims(reshape(c′, sz), reverse(ntuple(identity, N))) end kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b) @@ -284,6 +318,12 @@ for f in [:transpose, :adjoint, :inv] end end +function Base.reshape( + a::KroneckerArray, ax::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}} +) + return reshape(arg1(a), map(arg1, ax)) ⊗ reshape(arg2(a), map(arg2, ax)) +end + # Allows for customizations for FillArrays. _BroadcastStyle(x) = BroadcastStyle(x) @@ -405,8 +445,8 @@ Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a) Broadcast.broadcastable(a::KroneckerBroadcasted) = a Base.copy(a::KroneckerBroadcasted) = copy(arg1(a)) ⊗ copy(arg2(a)) function Base.copyto!(dest::KroneckerArray, a::KroneckerBroadcasted) - copyto!(arg1(dest), copy(arg1(a))) - copyto!(arg2(dest), copy(arg2(a))) + _copyto!!(arg1(dest), arg1(a)) + _copyto!!(arg2(dest), arg2(a)) return dest end function Base.eltype(a::KroneckerBroadcasted) diff --git a/test/Project.toml b/test/Project.toml index f649d4b..f37978a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" @@ -34,6 +35,7 @@ MatrixAlgebraKit = "0.2" SafeTestsets = "0.1" StableRNGs = "1.0" Suppressor = "0.2" +TensorAlgebra = "0.3.10" TensorProducts = "0.1.7" Test = "1.10" TestExtras = "0.3" diff --git a/test/test_basics.jl b/test/test_basics.jl index e17f8c1..b621730 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -46,10 +46,21 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test r[2 × 2] == 5 @test r[2 × 3] == 6 + @test sprint(show, "text/plain", cartesianrange(2 × 3)) == + "Base.OneTo(2) × Base.OneTo(3)\nBase.OneTo(6)" + @test sprint(show, cartesianrange(2 × 3)) == "Base.OneTo(6)" + # CartesianProductUnitRange axes r = cartesianrange((2:3) × (3:4), 2:5) @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + # CartesianProductUnitRange getindex + r1 = cartesianrange((2:4) × (3:5), 2:10) + r2 = cartesianrange((2:3) × (2:3), 2:5) + @test r1[r2] ≡ cartesianrange((3:4) × (4:5), 3:6) + + @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + # CartesianProductVector axes r = CartesianProductVector(([2, 4]) × ([3, 5]), [3, 5, 7, 9]) @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) @@ -178,6 +189,17 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test_throws ErrorException imag(a) end + # permutedims + a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) + @test permutedims(a, (2, 1, 3)) == + permutedims(arg1(a), (2, 1, 3)) ⊗ permutedims(arg2(a), (2, 1, 3)) + + # permutedims! + a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) + b = similar(a) + permutedims!(b, a, (2, 1, 3)) + @test b == permutedims(arg1(a), (2, 1, 3)) ⊗ permutedims(arg2(a), (2, 1, 3)) + # Adapt a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) a′ = adapt(JLArray, a) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 1dcf58a..db492db 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -1,10 +1,16 @@ using Adapt: adapt -using BlockArrays: Block, BlockRange, mortar +using BlockArrays: Block, BlockRange, blockedrange, blockisequal, mortar using BlockSparseArrays: - BlockIndexVector, BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype + BlockIndexVector, + BlockSparseArray, + BlockSparseMatrix, + blockrange, + blocksparse, + blocktype, + eachblockaxis using FillArrays: Eye, SquareEye using JLArrays: JLArray -using KroneckerArrays: KroneckerArray, ⊗, ×, arg1, arg2 +using KroneckerArrays: KroneckerArray, ⊗, ×, arg1, arg2, cartesianrange using LinearAlgebra: norm using MatrixAlgebraKit: svd_compact, svd_trunc using StableRNGs: StableRNG @@ -17,6 +23,15 @@ arrayts = (Array, JLArray) arrayts, elt in elts + # BlockUnitRange with CartesianProduct blocks + r = blockrange([2 × 3, 3 × 4]) + @test r[Block(1)] ≡ cartesianrange(2 × 3, 1:6) + @test r[Block(2)] ≡ cartesianrange(3 × 4, 7:18) + @test eachblockaxis(r)[1] ≡ cartesianrange(2, 3) + @test eachblockaxis(r)[2] ≡ cartesianrange(3, 4) + @test blockisequal(arg1(r), blockedrange([2, 3])) + @test blockisequal(arg2(r), blockedrange([3, 4])) + dev = adapt(arrayt) r = blockrange([2 × 2, 3 × 3]) d = Dict( diff --git a/test/test_fillarrays.jl b/test/test_fillarrays.jl index 62e0234..04ef6ad 100644 --- a/test/test_fillarrays.jl +++ b/test/test_fillarrays.jl @@ -1,12 +1,15 @@ +using Adapt: adapt using DerivableInterfaces: zero! +using DiagonalArrays: δ using FillArrays: Eye, Zeros -using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, ×, arg1, arg2 +using JLArrays: JLArray, jl +using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, ×, arg1, arg2, cartesianrange using LinearAlgebra: det, norm, pinv using StableRNGs: StableRNG using Test: @test, @test_throws, @testset using TestExtras: @constinferred -@testset "FillArrays.Eye" begin +@testset "FillArrays.Eye, DiagonalArrays.Delta" begin MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS if VERSION < v"1.11-" # `cbrt(::AbstractMatrix{<:Real})` was implemented in Julia 1.11. @@ -15,9 +18,130 @@ using TestExtras: @constinferred a = Eye(2) ⊗ randn(3, 3) @test size(a) == (6, 6) - @test a + a == Eye(2) ⊗ (2a.b) - @test 2a == Eye(2) ⊗ (2a.b) - @test a * a == Eye(2) ⊗ (a.b * a.b) + @test a + a == Eye(2) ⊗ (2 * arg2(a)) + @test 2a == Eye(2) ⊗ (2 * arg2(a)) + @test a * a == Eye(2) ⊗ (arg2(a) * arg2(a)) + @test arg1(a[(:) × (:), (:) × (:)]) ≡ Eye(2) + @test arg1(view(a, (:) × (:), (:) × (:))) ≡ Eye(2) + @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ Eye(2) + @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ Eye(2) + @test arg1(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) + @test arg1(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ Eye(2) + @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) + @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + Eye(2) + @test arg1(adapt(JLArray, a)) ≡ Eye(2) + @test arg2(adapt(JLArray, a)) == jl(arg2(a)) + @test arg2(adapt(JLArray, a)) isa JLArray + @test arg1(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ Eye(3) + @test arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ Eye(3) + @test arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + Eye{Float32}(3) + @test arg1(copy(a)) ≡ Eye(2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg1(copyto!(b, a)) ≡ Eye(2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg1(permutedims(a, (2, 1))) ≡ Eye(2) + @test arg2(permutedims(a, (2, 1))) == permutedims(arg2(a), (2, 1)) + b = similar(a) + @test arg1(permutedims!(b, a, (2, 1))) ≡ Eye(2) + @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) + + a = randn(3, 3) ⊗ Eye(2) + @test size(a) == (6, 6) + @test a + a == (2 * arg1(a)) ⊗ Eye(2) + @test 2a == (2 * arg1(a)) ⊗ Eye(2) + @test a * a == (arg1(a) * arg1(a)) ⊗ Eye(2) + @test arg2(a[(:) × (:), (:) × (:)]) ≡ Eye(2) + @test arg2(view(a, (:) × (:), (:) × (:))) ≡ Eye(2) + @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ Eye(2) + @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ Eye(2) + @test arg2(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) + @test arg2(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ Eye(2) + @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ Eye(2) + @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + Eye(2) + @test arg2(adapt(JLArray, a)) ≡ Eye(2) + @test arg1(adapt(JLArray, a)) == jl(arg1(a)) + @test arg1(adapt(JLArray, a)) isa JLArray + @test arg2(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ Eye(3) + @test arg2(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ Eye(3) + @test arg2(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + Eye{Float32}(3) + @test arg2(copy(a)) ≡ Eye(2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg2(copyto!(b, a)) ≡ Eye(2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg2(permutedims(a, (2, 1))) ≡ Eye(2) + @test arg1(permutedims(a, (2, 1))) == permutedims(arg1(a), (2, 1)) + b = similar(a) + @test arg2(permutedims!(b, a, (2, 1))) ≡ Eye(2) + @test arg1(permutedims!(b, a, (2, 1))) == permutedims(arg1(a), (2, 1)) + + a = δ(2, 2) ⊗ randn(3, 3) + @test size(a) == (6, 6) + @test a + a == δ(2, 2) ⊗ (2 * arg2(a)) + @test 2a == δ(2, 2) ⊗ (2 * arg2(a)) + @test a * a == δ(2, 2) ⊗ (arg2(a) * arg2(a)) + @test arg1(a[(:) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ δ(2, 2) + @test arg1(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg1(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ δ(2, 2) + @test arg1(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg1(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + δ(2, 2) + @test arg1(adapt(JLArray, a)) ≡ δ(2, 2) + @test arg2(adapt(JLArray, a)) == jl(arg2(a)) + @test arg2(adapt(JLArray, a)) isa JLArray + @test arg1(similar(a, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ δ(3, 3) + @test arg1(similar(typeof(a), (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ δ(3, 3) + @test arg1(similar(a, Float32, (cartesianrange(3 × 2), cartesianrange(3 × 2)))) ≡ + δ(Float32, 3, 3) + @test arg1(copy(a)) ≡ δ(2, 2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg1(copyto!(b, a)) ≡ δ(2, 2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg1(permutedims(a, (2, 1))) ≡ δ(2, 2) + @test arg2(permutedims(a, (2, 1))) == permutedims(arg2(a), (2, 1)) + b = similar(a) + @test arg1(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) + @test arg2(permutedims!(b, a, (2, 1))) == permutedims(arg2(a), (2, 1)) + + a = randn(3, 3) ⊗ δ(2, 2) + @test size(a) == (6, 6) + @test a + a == (2 * arg1(a)) ⊗ δ(2, 2) + @test 2a == (2 * arg1(a)) ⊗ δ(2, 2) + @test a * a == (arg1(a) * arg1(a)) ⊗ δ(2, 2) + @test arg2(a[(:) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, (:) × (:), (:) × (:))) ≡ δ(2, 2) + @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), (:) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), (:) × (:))) ≡ δ(2, 2) + @test arg2(a[(:) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, (:) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ δ(2, 2) + @test arg2(a[Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:)]) ≡ δ(2, 2) + @test arg2(view(a, Base.Slice(Base.OneTo(2)) × (:), Base.Slice(Base.OneTo(2)) × (:))) ≡ + δ(2, 2) + @test arg2(adapt(JLArray, a)) ≡ δ(2, 2) + @test arg1(adapt(JLArray, a)) == jl(arg1(a)) + @test arg1(adapt(JLArray, a)) isa JLArray + @test arg2(similar(a, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ δ(3, 3) + @test arg2(similar(typeof(a), (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ δ(3, 3) + @test arg2(similar(a, Float32, (cartesianrange(2 × 3), cartesianrange(2 × 3)))) ≡ + δ(Float32, (3, 3)) + @test arg2(copy(a)) ≡ δ(2, 2) + @test arg2(copy(a)) == arg2(a) + b = similar(a) + @test arg2(copyto!(b, a)) ≡ δ(2, 2) + @test arg2(copyto!(b, a)) == arg2(a) + @test arg2(permutedims(a, (2, 1))) ≡ δ(2, 2) + @test arg1(permutedims(a, (2, 1))) == permutedims(arg1(a), (2, 1)) + b = similar(a) + @test arg2(permutedims!(b, a, (2, 1))) ≡ δ(2, 2) + @test arg1(permutedims!(b, a, (2, 1))) == permutedims(arg1(a), (2, 1)) # Views a = @constinferred(Eye(2) ⊗ randn(3, 3)) @@ -204,6 +328,24 @@ using TestExtras: @constinferred @test fa.b isa Eye @test det(a) ≈ det(collect(a)) ≈ 1 + + # permutedims + a = Eye(2, 2) ⊗ randn(3, 3) + @test permutedims(a, (2, 1)) == Eye(2, 2) ⊗ permutedims(arg2(a), (2, 1)) + + a = randn(2, 2) ⊗ Eye(3, 3) + @test permutedims(a, (2, 1)) == permutedims(arg1(a), (2, 1)) ⊗ Eye(3, 3) + + # permutedims! + a = Eye(2, 2) ⊗ randn(3, 3) + b = similar(a) + permutedims!(b, a, (2, 1)) + @test b == Eye(2, 2) ⊗ permutedims(arg2(a), (2, 1)) + + a = randn(3, 3) ⊗ Eye(2, 2) + b = similar(a) + permutedims!(b, a, (2, 1)) + @test b == permutedims(arg1(a), (2, 1)) ⊗ Eye(2, 2) end @testset "FillArrays.Zeros" begin diff --git a/test/test_tensoralgebra.jl b/test/test_tensoralgebra.jl new file mode 100644 index 0000000..35cac2a --- /dev/null +++ b/test/test_tensoralgebra.jl @@ -0,0 +1,10 @@ +using TensorAlgebra: matricize, unmatricize +using KroneckerArrays: ⊗, arg1, arg2 +using Test: @test, @testset + +@testset "TensorAlgebraExt" begin + a = randn(2, 2, 2) ⊗ randn(3, 3, 3) + m = matricize(a, (1, 2), (3,)) + @test m == matricize(arg1(a), (1, 2), (3,)) ⊗ matricize(arg2(a), (1, 2), (3,)) + @test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a +end