Skip to content

Commit 0b5388f

Browse files
authored
Implement some operations with Zeros (#22)
1 parent 5e70c2f commit 0b5388f

File tree

5 files changed

+95
-10
lines changed

5 files changed

+95
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.16"
4+
version = "0.1.17"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/fillarrays/kroneckerarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ function _BroadcastStyle(::Type{<:Eye})
228228
return EyeStyle()
229229
end
230230
Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle()
231+
Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2
231232

232233
function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type)
233234
return Eye{elt}(axes(bc))

src/kroneckerarray.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,65 @@ function Base.broadcasted(
323323
return broadcasted(style, Base.Fix2(/, f.args[2]), a)
324324
end
325325

326+
# Simplification rules similar to those for FillArrays.jl:
327+
# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl
328+
using FillArrays: Zeros
329+
function Base.broadcasted(
330+
style::KroneckerStyle,
331+
::typeof(+),
332+
a::KroneckerArray,
333+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
334+
)
335+
# TODO: Promote the element types.
336+
return a
337+
end
338+
function Base.broadcasted(
339+
style::KroneckerStyle,
340+
::typeof(+),
341+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
342+
b::KroneckerArray,
343+
)
344+
# TODO: Promote the element types.
345+
return b
346+
end
347+
function Base.broadcasted(
348+
style::KroneckerStyle,
349+
::typeof(+),
350+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
351+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
352+
)
353+
# TODO: Promote the element types and axes.
354+
return b
355+
end
356+
function Base.broadcasted(
357+
style::KroneckerStyle,
358+
::typeof(-),
359+
a::KroneckerArray,
360+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
361+
)
362+
# TODO: Promote the element types.
363+
return a
364+
end
365+
function Base.broadcasted(
366+
style::KroneckerStyle,
367+
::typeof(-),
368+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
369+
b::KroneckerArray,
370+
)
371+
# TODO: Promote the element types.
372+
# TODO: Return `broadcasted(-, b)`.
373+
return -b
374+
end
375+
function Base.broadcasted(
376+
style::KroneckerStyle,
377+
::typeof(-),
378+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
379+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
380+
)
381+
# TODO: Promote the element types and axes.
382+
return b
383+
end
384+
326385
# TODO: Define by converting to a broadcast expession (with MapBroadcast.jl)
327386
# and then constructing the output with `similar`.
328387
function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...)

test/test_blocksparsearrays.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,13 @@ end
159159
@test_broken a[Block.(1:2), Block(2)]
160160

161161
@testset "Block deficient" begin
162-
d = Dict(Block(1, 1) => Eye{elt}(2, 2) dev(randn(elt, 2, 2)))
163-
a = @constinferred dev(blocksparse(d, r, r))
162+
da = Dict(Block(1, 1) => Eye{elt}(2, 2) dev(randn(elt, 2, 2)))
163+
a = @constinferred dev(blocksparse(da, r, r))
164164

165-
d = Dict(Block(2, 2) => Eye{elt}(3, 3) dev(randn(elt, 3, 3)))
166-
b = @constinferred dev(blocksparse(d, r, r))
165+
db = Dict(Block(2, 2) => Eye{elt}(3, 3) dev(randn(elt, 3, 3)))
166+
b = @constinferred dev(blocksparse(db, r, r))
167167

168-
@test_broken a + b
169-
# @test Array(a + b) ≈ Array(a) + Array(b)
170-
# @test Array(2a) ≈ 2Array(a)
168+
@test Array(a + b) Array(a) + Array(b)
169+
@test Array(2a) 2Array(a)
171170
end
172171
end

test/test_fillarrays.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
using DerivableInterfaces: zero!
2-
using FillArrays: Eye
2+
using FillArrays: Eye, Zeros
33
using KroneckerArrays: KroneckerArrays, KroneckerArray,
44
using LinearAlgebra: det, norm, pinv
55
using StableRNGs: StableRNG
6-
using Test: @test, @testset
6+
using Test: @test, @test_throws, @testset
77

88
@testset "FillArrays.Eye" begin
99
MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS
@@ -190,3 +190,29 @@ using Test: @test, @testset
190190

191191
@test det(a) det(collect(a)) 1
192192
end
193+
194+
@testset "FillArrays.Zeros" begin
195+
a = randn(2, 2) randn(2, 2)
196+
b = Zeros(2, 2) Zeros(2, 2)
197+
for (x, y) in ((a, b), (b, a))
198+
@test x + y == a
199+
@test x .+ y == a
200+
@test map!(+, similar(a), x, y) == a
201+
@test (similar(a) .= x .+ y) == a
202+
end
203+
204+
@test a - b == a
205+
@test a .- b == a
206+
@test map!(-, similar(a), a, b) == a
207+
@test (similar(a) .= a .- b) == a
208+
209+
@test b - a == -a
210+
@test b .- a == -a
211+
@test map!(-, similar(a), b, a) == -a
212+
@test (similar(a) .= b .- a) == -a
213+
214+
@test b + b == b
215+
@test b .+ b == b
216+
@test b - b == b
217+
@test b .- b == b
218+
end

0 commit comments

Comments
 (0)