Skip to content

Define more broadcasting operations #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.14"
version = "0.1.15"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -10,6 +10,7 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"

[weakdeps]
Expand All @@ -28,5 +29,6 @@ DiagonalArrays = "0.3.5"
FillArrays = "1.13.0"
GPUArraysCore = "0.2.0"
LinearAlgebra = "1.10"
MapBroadcast = "0.1.9"
MatrixAlgebraKit = "0.2.0"
julia = "1.10"
9 changes: 9 additions & 0 deletions src/cartesianproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,12 @@ for f in (:+, :-)
end
end
end

using Base.Broadcast: axistype
function Base.Broadcast.axistype(
r1::CartesianProductUnitRange, r2::CartesianProductUnitRange
)
prod = axistype(arg1(r1), arg1(r2)) × axistype(arg2(r1), arg2(r2))
range = axistype(unproduct(r1), unproduct(r2))
return cartesianrange(prod, range)
end
24 changes: 24 additions & 0 deletions src/fillarrays/kroneckerarray.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
using FillArrays: FillArrays, Zeros
function FillArrays.fillsimilar(
a::Zeros{T},
ax::Tuple{
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
},
) where {T}
return Zeros{T}(arg1.(ax)) ⊗ Zeros{T}(arg2.(ax))
end

using FillArrays: RectDiagonal, OnesVector
const RectEye{T,V<:OnesVector{T},Axes} = RectDiagonal{T,V,Axes}

Expand Down Expand Up @@ -208,3 +218,17 @@ end
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye)
return error("Can't write in-place.")
end

using Base.Broadcast:
AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted

struct EyeStyle <: AbstractArrayStyle{2} end
EyeStyle(::Val{2}) = EyeStyle()
function _BroadcastStyle(::Type{<:Eye})
return EyeStyle()
end
Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle()

function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type)
return Eye{elt}(axes(bc))
end
83 changes: 71 additions & 12 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ end
for op in (:+, :-)
@eval begin
function Base.$op(a::KroneckerArray, b::KroneckerArray)
iszero(a) && return $op(b)
iszero(b) && return a
if a.b == b.b
return $op(a.a, b.a) ⊗ a.b
elseif a.a == b.a
Expand All @@ -241,8 +243,15 @@ for op in (:+, :-)
end
end

using Base.Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted
# Allows for customizations for FillArrays.
_BroadcastStyle(x) = BroadcastStyle(x)

using Base.Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted
struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
arg1(::Type{<:KroneckerStyle{<:Any,A}}) where {A} = A
arg1(style::KroneckerStyle) = arg1(typeof(style))
arg2(::Type{<:KroneckerStyle{<:Any,B}}) where {B} = B
arg2(style::KroneckerStyle) = arg2(typeof(style))
function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N}
return KroneckerStyle{N,a,b}()
end
Expand All @@ -253,30 +262,69 @@ function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}()
end
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B}
return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B))
return KroneckerStyle{N}(_BroadcastStyle(A), _BroadcastStyle(B))
end
function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N}
return KroneckerStyle{N}(
BroadcastStyle(style1.a, style2.a), BroadcastStyle(style1.b, style2.b)
)
style_a = BroadcastStyle(arg1(style1), arg1(style2))
(style_a isa Broadcast.Unknown) && return Broadcast.Unknown()
style_b = BroadcastStyle(arg2(style1), arg2(style2))
(style_b isa Broadcast.Unknown) && return Broadcast.Unknown()
return KroneckerStyle{N}(style_a, style_b)
end
function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B}
ax_a = map(ax -> ax.product.a, axes(bc))
ax_b = map(ax -> ax.product.b, axes(bc))
ax_a = arg1.(axes(bc))
ax_b = arg2.(axes(bc))
bc_a = Broadcasted(A, nothing, (), ax_a)
bc_b = Broadcasted(B, nothing, (), ax_b)
a = similar(bc_a, elt)
b = similar(bc_b, elt)
return a ⊗ b
end
# Fallback definition of broadcasting falls back to `map` but assumes
# inputs have been canonicalized to a map-compatible expression already,
# for example by absorbing scalar arguments into the function.
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:KroneckerStyle})
return throw(
ArgumentError(
"Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
),
)
allequal(axes, bc.args) || throw(ArgumentError("Broadcasted axes must be equal."))
map!(bc.f, dest, bc.args...)
return dest
end

# Broadcast rewrite rules. Canonicalize inputs to absorb scalar inputs into the
# function.
function Base.broadcasted(style::KroneckerStyle, ::typeof(*), a::Number, b::KroneckerArray)
return broadcasted(style, Base.Fix1(*, a), b)
end
function Base.broadcasted(style::KroneckerStyle, ::typeof(*), a::KroneckerArray, b::Number)
return broadcasted(style, Base.Fix2(*, b), a)
end
function Base.broadcasted(style::KroneckerStyle, ::typeof(/), a::KroneckerArray, b::Number)
return broadcasted(style, Base.Fix2(/, b), a)
end
using MapBroadcast: MapBroadcast, MapFunction
function Base.broadcasted(
style::KroneckerStyle,
f::MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}},
a::KroneckerArray,
)
return broadcasted(style, Base.Fix1(*, f.args[1]), a)
end
function Base.broadcasted(
style::KroneckerStyle,
f::MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}},
a::KroneckerArray,
)
return broadcasted(style, Base.Fix2(*, f.args[2]), a)
end
function Base.broadcasted(
style::KroneckerStyle,
f::MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}},
a::KroneckerArray,
)
return broadcasted(style, Base.Fix2(/, f.args[2]), a)
end

# TODO: Define by converting to a broadcast expession (with MapBroadcast.jl)
# and then constructing the output with `similar`.
function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...)
return throw(
ArgumentError(
Expand Down Expand Up @@ -312,6 +360,8 @@ for f in [:+, :-]
function Base.map!(
::typeof($f), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray
)
iszero(b) && return map!(identity, dest, a)
iszero(a) && return map!($f, dest, b)
if a.b == b.b
map!($f, dest.a, a.a, b.a)
map!(identity, dest.b, a.b)
Expand Down Expand Up @@ -350,6 +400,15 @@ for op in [:*, :/]
end
end
end
for f in [:+, :-]
@eval begin
function Base.map!(::typeof($f), dest::KroneckerArray, src::KroneckerArray)
map!($f, dest.a, src.a)
map!(identity, dest.b, src.b)
return dest
end
end
end

using DiagonalArrays: DiagonalArrays, diagonal
function DiagonalArrays.diagonal(a::KroneckerArray)
Expand Down
9 changes: 9 additions & 0 deletions src/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ using LinearAlgebra:
svdvals,
tr

using LinearAlgebra: LinearAlgebra
function KroneckerArray(J::LinearAlgebra.UniformScaling, ax::Tuple)
return Eye{eltype(J)}(arg1.(ax)) ⊗ Eye{eltype(J)}(arg2.(ax))
end
function Base.copyto!(a::KroneckerArray, J::LinearAlgebra.UniformScaling)
copyto!(a, KroneckerArray(J, axes(a)))
return a
end

using LinearAlgebra: LinearAlgebra, pinv
function LinearAlgebra.pinv(a::KroneckerArray; kwargs...)
return pinv(a.a; kwargs...) ⊗ pinv(a.b; kwargs...)
Expand Down
7 changes: 4 additions & 3 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,15 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
a′ = similar(a)
@test_throws "not supported" a′ .= sin.(a)
a′ = similar(a)
@test_broken a′ .= 2 .* a
a′ .= 2 .* a
@test collect(a′) ≈ 2 * collect(a)
bc = broadcasted(+, a, a)
@test bc.style === style
@test similar(bc, elt) isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)}
@test_broken copy(bc)
@test collect(copy(bc)) ≈ 2 * collect(a)
bc = broadcasted(*, 2, a)
@test bc.style === style
@test_broken copy(bc)
@test collect(copy(bc)) ≈ 2 * collect(a)

# Mapping
a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
Expand Down
32 changes: 23 additions & 9 deletions test/test_blocksparsearrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@ arrayts = (Array, JLArray)
@test_broken inv(a)
end

if (VERSION ≤ v"1.11-" && arrayt === Array && elt <: Complex) ||
(arrayt === Array && elt <: Real)
if arrayt === Array
u, s, v = svd_compact(a)
@test Array(u * s * v) ≈ Array(a)
else
# Broken on GPU and for complex, investigate.
# Broken on GPU.
@test_broken svd_compact(a)
end

Expand Down Expand Up @@ -135,19 +134,34 @@ end
@test_broken exp(a)
end

if VERSION < v"1.11-" && elt <: Complex
# Broken because of type stability issue in Julia v1.10.
@test_broken svd_compact(a)
elseif arrayt === Array
## if VERSION < v"1.11-" && elt <: Complex
## # Broken because of type stability issue in Julia v1.10.
## @test_broken svd_compact(a)
if arrayt === Array
u, s, v = svd_compact(a)
@test u * s * v ≈ a
@test blocktype(u) === blocktype(a)
@test blocktype(v) === blocktype(a)
@test blocktype(u) >: blocktype(u)
@test eltype(u) === eltype(a)
@test blocktype(v) >: blocktype(a)
@test eltype(v) === eltype(a)
@test eltype(s) === real(eltype(a))
else
@test_broken svd_compact(a)
end

# Broken operations
@test_broken inv(a)
@test_broken a[Block.(1:2), Block(2)]

@testset "Block deficient" begin
d = Dict(Block(1, 1) => Eye{elt}(2, 2) ⊗ dev(randn(elt, 2, 2)))
a = @constinferred dev(blocksparse(d, r, r))

d = Dict(Block(2, 2) => Eye{elt}(3, 3) ⊗ dev(randn(elt, 3, 3)))
b = @constinferred dev(blocksparse(d, r, r))

@test_broken a + b
# @test Array(a + b) ≈ Array(a) + Array(b)
# @test Array(2a) ≈ 2Array(a)
end
end
Loading