diff --git a/Project.toml b/Project.toml index 8e8934f..f0ba43d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 0783396..c072dbe 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -112,6 +112,11 @@ function Base.similar( ) where {A,B} return similar(A, map(ax -> ax.product.a, axs)) ⊗ similar(B, map(ax -> ax.product.b, axs)) end +function Base.similar( + ::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, sz::Tuple{Int,Vararg{Int}} +) where {A,B} + return similar(promote_type(A, B), sz) +end Base.collect(a::KroneckerArray) = kron(a.a, a.b) diff --git a/test/Project.toml b/test/Project.toml index 772eff2..790a8e3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,6 +11,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Aqua = "0.8" KroneckerArrays = "0.1" LinearAlgebra = "1.10" +MatrixAlgebraKit = "0.2" SafeTestsets = "0.1" Suppressor = "0.2" Test = "1.10" diff --git a/test/test_basics.jl b/test/test_basics.jl index bb44654..17892cc 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -11,6 +11,8 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c = a.a ⊗ b.b + @test similar(typeof(a), (2, 3)) isa Matrix{elt} + @test size(similar(typeof(a), (2, 3))) == (2, 3) @test a[1 × 1, 1 × 1] == a.a[1, 1] * a.b[1, 1] @test a[1 × 3, 2 × 1] == a.a[1, 2] * a.b[3, 1] @test a[1 × (2:3), 2 × 1] == a.a[1, 2] * a.b[2:3, 1]