diff --git a/Project.toml b/Project.toml index 365054b..4e96312 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.16" +version = "0.1.17" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 0a39106..f8d17b1 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -228,6 +228,7 @@ function _BroadcastStyle(::Type{<:Eye}) return EyeStyle() end Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle() +Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2 function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type) return Eye{elt}(axes(bc)) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 05f9b26..b20f9b3 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -323,6 +323,65 @@ function Base.broadcasted( return broadcasted(style, Base.Fix2(/, f.args[2]), a) end +# Simplification rules similar to those for FillArrays.jl: +# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl +using FillArrays: Zeros +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(+), + a::KroneckerArray, + b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +) + # TODO: Promote the element types. + return a +end +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(+), + a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, + b::KroneckerArray, +) + # TODO: Promote the element types. + return b +end +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(+), + a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, + b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +) + # TODO: Promote the element types and axes. + return b +end +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(-), + a::KroneckerArray, + b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +) + # TODO: Promote the element types. + return a +end +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(-), + a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, + b::KroneckerArray, +) + # TODO: Promote the element types. + # TODO: Return `broadcasted(-, b)`. + return -b +end +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(-), + a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, + b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +) + # TODO: Promote the element types and axes. + return b +end + # TODO: Define by converting to a broadcast expession (with MapBroadcast.jl) # and then constructing the output with `similar`. function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index d5d0e56..1d37e15 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -159,14 +159,13 @@ end @test_broken a[Block.(1:2), Block(2)] @testset "Block deficient" begin - d = Dict(Block(1, 1) => Eye{elt}(2, 2) ⊗ dev(randn(elt, 2, 2))) - a = @constinferred dev(blocksparse(d, r, r)) + da = Dict(Block(1, 1) => Eye{elt}(2, 2) ⊗ dev(randn(elt, 2, 2))) + a = @constinferred dev(blocksparse(da, r, r)) - d = Dict(Block(2, 2) => Eye{elt}(3, 3) ⊗ dev(randn(elt, 3, 3))) - b = @constinferred dev(blocksparse(d, r, r)) + db = Dict(Block(2, 2) => Eye{elt}(3, 3) ⊗ dev(randn(elt, 3, 3))) + b = @constinferred dev(blocksparse(db, r, r)) - @test_broken a + b - # @test Array(a + b) ≈ Array(a) + Array(b) - # @test Array(2a) ≈ 2Array(a) + @test Array(a + b) ≈ Array(a) + Array(b) + @test Array(2a) ≈ 2Array(a) end end diff --git a/test/test_fillarrays.jl b/test/test_fillarrays.jl index 001cda9..f852ef9 100644 --- a/test/test_fillarrays.jl +++ b/test/test_fillarrays.jl @@ -1,9 +1,9 @@ using DerivableInterfaces: zero! -using FillArrays: Eye +using FillArrays: Eye, Zeros using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗ using LinearAlgebra: det, norm, pinv using StableRNGs: StableRNG -using Test: @test, @testset +using Test: @test, @test_throws, @testset @testset "FillArrays.Eye" begin MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS @@ -190,3 +190,29 @@ using Test: @test, @testset @test det(a) ≈ det(collect(a)) ≈ 1 end + +@testset "FillArrays.Zeros" begin + a = randn(2, 2) ⊗ randn(2, 2) + b = Zeros(2, 2) ⊗ Zeros(2, 2) + for (x, y) in ((a, b), (b, a)) + @test x + y == a + @test x .+ y == a + @test map!(+, similar(a), x, y) == a + @test (similar(a) .= x .+ y) == a + end + + @test a - b == a + @test a .- b == a + @test map!(-, similar(a), a, b) == a + @test (similar(a) .= a .- b) == a + + @test b - a == -a + @test b .- a == -a + @test map!(-, similar(a), b, a) == -a + @test (similar(a) .= b .- a) == -a + + @test b + b == b + @test b .+ b == b + @test b - b == b + @test b .- b == b +end