Skip to content

Commit 76cc220

Browse files
mtfishmanclaude
andauthored
Strict axis equality, remove space_isequal (#138)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c3bcb89 commit 76cc220

17 files changed

+341
-400
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
3-
version = "0.6.23"
3+
version = "0.7.0"
44
authors = ["ITensor developers <support@itensor.org> and contributors"]
55

66
[workspace]

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ path = ".."
99

1010
[compat]
1111
Documenter = "1"
12-
GradedArrays = "0.6"
12+
GradedArrays = "0.7"
1313
ITensorFormatter = "0.2.27"
1414
Literate = "2"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
55
path = ".."
66

77
[compat]
8-
GradedArrays = "0.6"
8+
GradedArrays = "0.7"

src/GradedArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ export gradedrange
1111
export dual, flip, gradedrange, isdual,
1212
sector, sector_multiplicities, sector_multiplicity,
1313
sectorrange, sectors, sector_type,
14-
space_isequal, ungrade
14+
ungrade
1515

1616
# imports
1717
# -------

src/broadcast.jl

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,29 +31,11 @@ function Base.Broadcast.materialize(a::SectorArray)
3131
return ofsector(a, Base.Broadcast.materialize(a.data))
3232
end
3333

34-
function check_sector_broadcast_axes(a::SectorArray, b::SectorArray)
35-
axes(a) == axes(b) ||
36-
throw(ArgumentError("SectorArray linear broadcasting requires matching axes"))
37-
return nothing
38-
end
39-
4034
function TensorAlgebra.:+(a::SectorArray, b::SectorArray)
41-
check_sector_broadcast_axes(a, b)
35+
_check_add_axes(a, b)
4236
return ofsector(a, a.data +ₗ b.data)
4337
end
4438

45-
function TensorAlgebra.add!(dest::AbstractArray, src::SectorArray, α::Number, β::Number)
46-
require_unique_fusion(src)
47-
TensorAlgebra.add!(dest, src.data, α, β)
48-
return dest
49-
end
50-
51-
function TensorAlgebra.add!(dest::SectorArray, src::SectorArray, α::Number, β::Number)
52-
check_sector_broadcast_axes(dest, src)
53-
TensorAlgebra.add!(dest.data, src.data, α, β)
54-
return dest
55-
end
56-
5739
function TensorAlgebra.:*::Number, a::SectorArray)
5840
return ofsector(a, α *ₗ a.data)
5941
end
@@ -108,10 +90,10 @@ function Base.similar(bc::BC.Broadcasted{<:GradedStyle}, elt::Type, ax)
10890
return graded_similar(arg, elt, ax)
10991
end
11092

111-
function check_graded_broadcast_axes(a::AbstractArray, b::AbstractArray)
112-
all(dim -> space_isequal(axes(a, dim), axes(b, dim)), 1:ndims(a)) ||
93+
function _check_add_axes(a::AbstractArray, b::AbstractArray)
94+
axes(a) == axes(b) ||
11395
throw(
114-
ArgumentError("GradedArray linear broadcasting requires matching graded axes")
96+
ArgumentError("linear broadcasting requires matching axes")
11597
)
11698
return nothing
11799
end
@@ -211,7 +193,7 @@ function copy_lazygraded(a::LazyGradedArray)
211193
end
212194

213195
function TensorAlgebra.:+(a::LazyGradedArray, b::LazyGradedArray)
214-
check_graded_broadcast_axes(a, b)
196+
_check_add_axes(a, b)
215197
return AddGradedArray(a, b)
216198
end
217199
TensorAlgebra.:*::Number, a::GradedArray) = ScaledGradedArray(α, a)

src/gradedarray.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,16 @@ function KroneckerArrays.:×(g1::GradedUnitRange, g2::GradedUnitRange)
4545
return mortar_axis(v)
4646
end
4747

48-
function space_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange)
49-
return (isdual(a1) == isdual(a2)) && sectors(a1) == sectors(a2) && blockisequal(a1, a2)
48+
function Base.isequal(a::GradedUnitRange, b::GradedUnitRange)
49+
ea = eachblockaxis(a)
50+
eb = eachblockaxis(b)
51+
return length(ea) == length(eb) && all(splat(isequal), zip(ea, eb))
5052
end
53+
Base.:(==)(a::GradedUnitRange, b::GradedUnitRange) = isequal(a, b)
54+
Base.:(==)(::GradedUnitRange, ::AbstractUnitRange) = false
55+
Base.:(==)(::AbstractUnitRange, ::GradedUnitRange) = false
56+
Base.isequal(::GradedUnitRange, ::AbstractUnitRange) = false
57+
Base.isequal(::AbstractUnitRange, ::GradedUnitRange) = false
5158

5259
function BlockSparseArrays.blockrange(xs::Vector{<:GradedUnitRange})
5360
baxis = mapreduce(eachblockaxis, vcat, xs)
@@ -123,6 +130,21 @@ const GradedArray{T, N, I, A, Blocks, Axes <: NTuple{N, GradedUnitRange{I}}} =
123130
const GradedMatrix{T, I, A, Blocks, Axes} = GradedArray{T, 2, A, Blocks, Axes}
124131
const GradedVector{T, I, A, Blocks, Axes} = GradedArray{T, 1, A, Blocks, Axes}
125132

133+
# Override ArrayLayouts._check_mul_axes for GradedArray.
134+
# For graded matrices, the contracted axes satisfy axes(A,2) == dual(axes(B,1)),
135+
# not axes(A,2) == axes(B,1), so the default check needs to account for duality.
136+
const GradedMatrixOrAdj = Union{
137+
GradedMatrix, LinearAlgebra.AdjOrTrans{<:Any, <:GradedMatrix},
138+
}
139+
function ArrayLayouts._check_mul_axes(A::GradedMatrixOrAdj, B::GradedMatrixOrAdj)
140+
axes(A, 2) == dual(axes(B, 1)) || throw(
141+
DimensionMismatch(
142+
"second axis of A, $(axes(A, 2)), and first axis of B, $(axes(B, 1)), must match"
143+
)
144+
)
145+
return nothing
146+
end
147+
126148
# Specific overloads
127149
# ------------------
128150
# convert Array to SectorArray upon insertion

src/sectorarray.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ end
3636
KroneckerArrays.:×(a::SectorRange, g::AbstractUnitRange) = cartesianrange(a, g)
3737
KroneckerArrays.:×(g::AbstractUnitRange, a::SectorRange) = cartesianrange(a, g)
3838

39+
function Base.isequal(a::SectorUnitRange, b::SectorUnitRange)
40+
return isequal(kroneckerfactors(a), kroneckerfactors(b))
41+
end
42+
function Base.:(==)(a::SectorUnitRange, b::SectorUnitRange)
43+
return isequal(a, b)
44+
end
45+
3946
to_gradedrange(g::SectorUnitRange) = mortar_axis([g])
4047

4148
"""
@@ -385,10 +392,10 @@ end
385392
# TODO: Define this as part of:
386393
# `check_input(::typeof(mul!), ::SectorMatrix, ::SectorMatrix, ::SectorMatrix)`
387394
function check_mul_axes(c::SectorMatrix, a::SectorMatrix, b::SectorMatrix)
388-
space_isequal(axes(a, 2), dual(axes(b, 1))) ||
395+
axes(a, 2) == dual(axes(b, 1)) ||
389396
throw(DimensionMismatch("$(axes(a, 2)) != dual($(axes(b, 1))))"))
390-
space_isequal(axes(c, 1), axes(a, 1)) || throw(DimensionMismatch())
391-
space_isequal(axes(c, 2), axes(b, 2)) || throw(DimensionMismatch())
397+
axes(c, 1) == axes(a, 1) || throw(DimensionMismatch())
398+
axes(c, 2) == axes(b, 2) || throw(DimensionMismatch())
392399
return nothing
393400
end
394401

@@ -413,6 +420,18 @@ function KroneckerArrays.:(⊗)(
413420
return SectorArray(A.sectors, collect(T, data))
414421
end
415422

423+
function TensorAlgebra.add!(dest::AbstractArray, src::SectorArray, α::Number, β::Number)
424+
require_unique_fusion(src)
425+
TensorAlgebra.add!(dest, src.data, α, β)
426+
return dest
427+
end
428+
429+
function TensorAlgebra.add!(dest::SectorArray, src::SectorArray, α::Number, β::Number)
430+
_check_add_axes(dest, src)
431+
TensorAlgebra.add!(dest.data, src.data, α, β)
432+
return dest
433+
end
434+
416435
# TODO: can we avoid this?
417436
function Base.materialize!(dst::SectorArray, src::KroneckerArrays.KroneckerBroadcasted)
418437
Base.materialize!(kroneckerfactors(dst, 1), kroneckerfactors(src, 1))

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ path = ".."
2525
Aqua = "0.8.11"
2626
BlockArrays = "1.6"
2727
BlockSparseArrays = "0.10"
28-
GradedArrays = "0.6"
28+
GradedArrays = "0.7"
2929
ITensorPkgSkeleton = "0.3.42"
3030
KroneckerArrays = "0.3"
3131
LinearAlgebra = "1.10"

test/test_exports.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ using Test: @test, @testset
3030
:sectorrange,
3131
:sectors,
3232
:sector_type,
33-
:space_isequal,
3433
:ungrade,
3534
]
3635
@test issetequal(names(GradedArrays), exports)

0 commit comments

Comments
 (0)