Skip to content

Commit 826a8cf

Browse files
authored
Define more broadcasting operations (#20)
1 parent ca8be69 commit 826a8cf

File tree

7 files changed

+143
-25
lines changed

7 files changed

+143
-25
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.14"
4+
version = "0.1.15"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -10,6 +10,7 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
1010
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1111
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
13+
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1314
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1415

1516
[weakdeps]
@@ -28,5 +29,6 @@ DiagonalArrays = "0.3.5"
2829
FillArrays = "1.13.0"
2930
GPUArraysCore = "0.2.0"
3031
LinearAlgebra = "1.10"
32+
MapBroadcast = "0.1.9"
3133
MatrixAlgebraKit = "0.2.0"
3234
julia = "1.10"

src/cartesianproduct.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,12 @@ for f in (:+, :-)
7575
end
7676
end
7777
end
78+
79+
using Base.Broadcast: axistype
80+
function Base.Broadcast.axistype(
81+
r1::CartesianProductUnitRange, r2::CartesianProductUnitRange
82+
)
83+
prod = axistype(arg1(r1), arg1(r2)) × axistype(arg2(r1), arg2(r2))
84+
range = axistype(unproduct(r1), unproduct(r2))
85+
return cartesianrange(prod, range)
86+
end

src/fillarrays/kroneckerarray.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
using FillArrays: FillArrays, Zeros
2+
function FillArrays.fillsimilar(
3+
a::Zeros{T},
4+
ax::Tuple{
5+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
6+
},
7+
) where {T}
8+
return Zeros{T}(arg1.(ax)) Zeros{T}(arg2.(ax))
9+
end
10+
111
using FillArrays: RectDiagonal, OnesVector
212
const RectEye{T,V<:OnesVector{T},Axes} = RectDiagonal{T,V,Axes}
313

@@ -208,3 +218,17 @@ end
208218
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye)
209219
return error("Can't write in-place.")
210220
end
221+
222+
using Base.Broadcast:
223+
AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted
224+
225+
struct EyeStyle <: AbstractArrayStyle{2} end
226+
EyeStyle(::Val{2}) = EyeStyle()
227+
function _BroadcastStyle(::Type{<:Eye})
228+
return EyeStyle()
229+
end
230+
Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle()
231+
232+
function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type)
233+
return Eye{elt}(axes(bc))
234+
end

src/kroneckerarray.jl

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ end
226226
for op in (:+, :-)
227227
@eval begin
228228
function Base.$op(a::KroneckerArray, b::KroneckerArray)
229+
iszero(a) && return $op(b)
230+
iszero(b) && return a
229231
if a.b == b.b
230232
return $op(a.a, b.a) a.b
231233
elseif a.a == b.a
@@ -241,8 +243,15 @@ for op in (:+, :-)
241243
end
242244
end
243245

244-
using Base.Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted
246+
# Allows for customizations for FillArrays.
247+
_BroadcastStyle(x) = BroadcastStyle(x)
248+
249+
using Base.Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted
245250
struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
251+
arg1(::Type{<:KroneckerStyle{<:Any,A}}) where {A} = A
252+
arg1(style::KroneckerStyle) = arg1(typeof(style))
253+
arg2(::Type{<:KroneckerStyle{<:Any,B}}) where {B} = B
254+
arg2(style::KroneckerStyle) = arg2(typeof(style))
246255
function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N}
247256
return KroneckerStyle{N,a,b}()
248257
end
@@ -253,30 +262,69 @@ function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
253262
return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}()
254263
end
255264
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B}
256-
return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B))
265+
return KroneckerStyle{N}(_BroadcastStyle(A), _BroadcastStyle(B))
257266
end
258267
function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N}
259-
return KroneckerStyle{N}(
260-
BroadcastStyle(style1.a, style2.a), BroadcastStyle(style1.b, style2.b)
261-
)
268+
style_a = BroadcastStyle(arg1(style1), arg1(style2))
269+
(style_a isa Broadcast.Unknown) && return Broadcast.Unknown()
270+
style_b = BroadcastStyle(arg2(style1), arg2(style2))
271+
(style_b isa Broadcast.Unknown) && return Broadcast.Unknown()
272+
return KroneckerStyle{N}(style_a, style_b)
262273
end
263274
function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B}
264-
ax_a = map(ax -> ax.product.a, axes(bc))
265-
ax_b = map(ax -> ax.product.b, axes(bc))
275+
ax_a = arg1.(axes(bc))
276+
ax_b = arg2.(axes(bc))
266277
bc_a = Broadcasted(A, nothing, (), ax_a)
267278
bc_b = Broadcasted(B, nothing, (), ax_b)
268279
a = similar(bc_a, elt)
269280
b = similar(bc_b, elt)
270281
return a b
271282
end
283+
# Fallback definition of broadcasting falls back to `map` but assumes
284+
# inputs have been canonicalized to a map-compatible expression already,
285+
# for example by absorbing scalar arguments into the function.
272286
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:KroneckerStyle})
273-
return throw(
274-
ArgumentError(
275-
"Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
276-
),
277-
)
287+
allequal(axes, bc.args) || throw(ArgumentError("Broadcasted axes must be equal."))
288+
map!(bc.f, dest, bc.args...)
289+
return dest
278290
end
279291

292+
# Broadcast rewrite rules. Canonicalize inputs to absorb scalar inputs into the
293+
# function.
294+
function Base.broadcasted(style::KroneckerStyle, ::typeof(*), a::Number, b::KroneckerArray)
295+
return broadcasted(style, Base.Fix1(*, a), b)
296+
end
297+
function Base.broadcasted(style::KroneckerStyle, ::typeof(*), a::KroneckerArray, b::Number)
298+
return broadcasted(style, Base.Fix2(*, b), a)
299+
end
300+
function Base.broadcasted(style::KroneckerStyle, ::typeof(/), a::KroneckerArray, b::Number)
301+
return broadcasted(style, Base.Fix2(/, b), a)
302+
end
303+
using MapBroadcast: MapBroadcast, MapFunction
304+
function Base.broadcasted(
305+
style::KroneckerStyle,
306+
f::MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}},
307+
a::KroneckerArray,
308+
)
309+
return broadcasted(style, Base.Fix1(*, f.args[1]), a)
310+
end
311+
function Base.broadcasted(
312+
style::KroneckerStyle,
313+
f::MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}},
314+
a::KroneckerArray,
315+
)
316+
return broadcasted(style, Base.Fix2(*, f.args[2]), a)
317+
end
318+
function Base.broadcasted(
319+
style::KroneckerStyle,
320+
f::MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}},
321+
a::KroneckerArray,
322+
)
323+
return broadcasted(style, Base.Fix2(/, f.args[2]), a)
324+
end
325+
326+
# TODO: Define by converting to a broadcast expession (with MapBroadcast.jl)
327+
# and then constructing the output with `similar`.
280328
function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...)
281329
return throw(
282330
ArgumentError(
@@ -312,6 +360,8 @@ for f in [:+, :-]
312360
function Base.map!(
313361
::typeof($f), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray
314362
)
363+
iszero(b) && return map!(identity, dest, a)
364+
iszero(a) && return map!($f, dest, b)
315365
if a.b == b.b
316366
map!($f, dest.a, a.a, b.a)
317367
map!(identity, dest.b, a.b)
@@ -350,6 +400,15 @@ for op in [:*, :/]
350400
end
351401
end
352402
end
403+
for f in [:+, :-]
404+
@eval begin
405+
function Base.map!(::typeof($f), dest::KroneckerArray, src::KroneckerArray)
406+
map!($f, dest.a, src.a)
407+
map!(identity, dest.b, src.b)
408+
return dest
409+
end
410+
end
411+
end
353412

354413
using DiagonalArrays: DiagonalArrays, diagonal
355414
function DiagonalArrays.diagonal(a::KroneckerArray)

src/linearalgebra.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ using LinearAlgebra:
1515
svdvals,
1616
tr
1717

18+
using LinearAlgebra: LinearAlgebra
19+
function KroneckerArray(J::LinearAlgebra.UniformScaling, ax::Tuple)
20+
return Eye{eltype(J)}(arg1.(ax)) Eye{eltype(J)}(arg2.(ax))
21+
end
22+
function Base.copyto!(a::KroneckerArray, J::LinearAlgebra.UniformScaling)
23+
copyto!(a, KroneckerArray(J, axes(a)))
24+
return a
25+
end
26+
1827
using LinearAlgebra: LinearAlgebra, pinv
1928
function LinearAlgebra.pinv(a::KroneckerArray; kwargs...)
2029
return pinv(a.a; kwargs...) pinv(a.b; kwargs...)

test/test_basics.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,15 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
8787
a′ = similar(a)
8888
@test_throws "not supported" a′ .= sin.(a)
8989
a′ = similar(a)
90-
@test_broken a′ .= 2 .* a
90+
a′ .= 2 .* a
91+
@test collect(a′) 2 * collect(a)
9192
bc = broadcasted(+, a, a)
9293
@test bc.style === style
9394
@test similar(bc, elt) isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)}
94-
@test_broken copy(bc)
95+
@test collect(copy(bc)) 2 * collect(a)
9596
bc = broadcasted(*, 2, a)
9697
@test bc.style === style
97-
@test_broken copy(bc)
98+
@test collect(copy(bc)) 2 * collect(a)
9899

99100
# Mapping
100101
a = randn(elt, 2, 2) randn(elt, 3, 3)

test/test_blocksparsearrays.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,11 @@ arrayts = (Array, JLArray)
6464
@test_broken inv(a)
6565
end
6666

67-
if (VERSION v"1.11-" && arrayt === Array && elt <: Complex) ||
68-
(arrayt === Array && elt <: Real)
67+
if arrayt === Array
6968
u, s, v = svd_compact(a)
7069
@test Array(u * s * v) Array(a)
7170
else
72-
# Broken on GPU and for complex, investigate.
71+
# Broken on GPU.
7372
@test_broken svd_compact(a)
7473
end
7574

@@ -135,19 +134,34 @@ end
135134
@test_broken exp(a)
136135
end
137136

138-
if VERSION < v"1.11-" && elt <: Complex
139-
# Broken because of type stability issue in Julia v1.10.
140-
@test_broken svd_compact(a)
141-
elseif arrayt === Array
137+
## if VERSION < v"1.11-" && elt <: Complex
138+
## # Broken because of type stability issue in Julia v1.10.
139+
## @test_broken svd_compact(a)
140+
if arrayt === Array
142141
u, s, v = svd_compact(a)
143142
@test u * s * v a
144-
@test blocktype(u) === blocktype(a)
145-
@test blocktype(v) === blocktype(a)
143+
@test blocktype(u) >: blocktype(u)
144+
@test eltype(u) === eltype(a)
145+
@test blocktype(v) >: blocktype(a)
146+
@test eltype(v) === eltype(a)
147+
@test eltype(s) === real(eltype(a))
146148
else
147149
@test_broken svd_compact(a)
148150
end
149151

150152
# Broken operations
151153
@test_broken inv(a)
152154
@test_broken a[Block.(1:2), Block(2)]
155+
156+
@testset "Block deficient" begin
157+
d = Dict(Block(1, 1) => Eye{elt}(2, 2) dev(randn(elt, 2, 2)))
158+
a = @constinferred dev(blocksparse(d, r, r))
159+
160+
d = Dict(Block(2, 2) => Eye{elt}(3, 3) dev(randn(elt, 3, 3)))
161+
b = @constinferred dev(blocksparse(d, r, r))
162+
163+
@test_broken a + b
164+
# @test Array(a + b) ≈ Array(a) + Array(b)
165+
# @test Array(2a) ≈ 2Array(a)
166+
end
153167
end

0 commit comments

Comments
 (0)