Skip to content

Commit b0bdc61

Browse files
mtfishmanclaude
andcommitted
Migrate from TensorAlgebra macro lazy types to LinearBroadcasted
- Delete ScaledGradedArray, ConjGradedArray, AddGradedArray and all macro-generated lazy types - Delete lazyblock, graded_eachblockstoredindex, copy_lazygraded - Replace BC.broadcasted(::GradedStyle/SectorStyle, ...) eager interception with copyto! instantiation-time conversion via tryflattenlinear - Add permutedimsopadd! overloads for GradedArray and SectorArray - Update permutedimsadd! → permutedimsopadd! with op parameter in tensoralgebra.jl - Replace y .*= β with blockwise scaling to avoid broadcasting cycle - Bump TensorAlgebra compat to 0.7.21, 0.8 - Bump version to 0.6.24 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 76cc220 commit b0bdc61

File tree

6 files changed

+39
-168
lines changed

6 files changed

+39
-168
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
3-
version = "0.7.0"
3+
version = "0.7.1"
44
authors = ["ITensor developers <support@itensor.org> and contributors"]
55

66
[workspace]
@@ -48,7 +48,7 @@ Random = "1.10"
4848
SUNRepresentations = "0.3"
4949
SparseArraysBase = "0.9"
5050
SplitApplyCombine = "1.2.3"
51-
TensorAlgebra = "0.7.20"
51+
TensorAlgebra = "0.7.21, 0.8"
5252
TensorKitSectors = "0.3"
5353
TypeParameterAccessors = "0.4"
5454
julia = "1.10"

src/broadcast.jl

Lines changed: 15 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using Base.Broadcast: Broadcast as BC
2-
using FillArrays: Zeros, fillsimilar
3-
using TensorAlgebra: TensorAlgebra, *ₗ, +ₗ, -ₗ, /ₗ, conjed
2+
using TensorAlgebra: TensorAlgebra
43

54
struct SectorStyle{I, N} <: BC.AbstractArrayStyle{N} end
65
SectorStyle{I, N}(::Val{M}) where {I, N, M} = SectorStyle{I, M}()
@@ -31,21 +30,17 @@ function Base.Broadcast.materialize(a::SectorArray)
3130
return ofsector(a, Base.Broadcast.materialize(a.data))
3231
end
3332

34-
function TensorAlgebra.:+(a::SectorArray, b::SectorArray)
35-
_check_add_axes(a, b)
36-
return ofsector(a, a.data +ₗ b.data)
37-
end
38-
39-
function TensorAlgebra.:*::Number, a::SectorArray)
40-
return ofsector(a, α *ₗ a.data)
41-
end
42-
TensorAlgebra.:*(a::SectorArray, α::Number) = α *ₗ a
43-
function TensorAlgebra.conjed(a::SectorArray)
44-
return ofsector(a, TensorAlgebra.conjed(a.data))
33+
function Base.similar(bc::BC.Broadcasted{<:SectorStyle}, elt::Type, ax)
34+
bc′ = BC.flatten(bc)
35+
arg = bc′.args[findfirst(arg -> arg isa SectorArray, bc′.args)]
36+
return ofsector(arg, similar(arg.data, elt))
4537
end
4638

47-
function BC.broadcasted(style::SectorStyle, f, args...)
48-
return TensorAlgebra.broadcasted_linear(style, f, args...)
39+
function Base.copyto!(dest::SectorArray, bc::BC.Broadcasted{<:SectorStyle})
40+
lb = TensorAlgebra.tryflattenlinear(bc)
41+
isnothing(lb) &&
42+
throw(ArgumentError("SectorArray broadcasting requires linear operations"))
43+
return copyto!(dest, lb)
4944
end
5045

5146
struct GradedStyle{I, N, B <: BC.AbstractArrayStyle{N}} <: BC.AbstractArrayStyle{N}
@@ -90,67 +85,6 @@ function Base.similar(bc::BC.Broadcasted{<:GradedStyle}, elt::Type, ax)
9085
return graded_similar(arg, elt, ax)
9186
end
9287

93-
function _check_add_axes(a::AbstractArray, b::AbstractArray)
94-
axes(a) == axes(b) ||
95-
throw(
96-
ArgumentError("linear broadcasting requires matching axes")
97-
)
98-
return nothing
99-
end
100-
101-
function lazyblock(a::GradedArray{<:Any, N}, I::Vararg{Block{1}, N}) where {N}
102-
if isstored(a, I...)
103-
return blocks(a)[Int.(I)...]
104-
else
105-
block_ax = map((ax, i) -> eachblockaxis(ax)[Int(i)], axes(a), I)
106-
return fillsimilar(Zeros{eltype(a)}(block_ax), block_ax)
107-
end
108-
end
109-
lazyblock(a::GradedArray, I::Block) = lazyblock(a, Tuple(I)...)
110-
111-
TensorAlgebra.@scaledarray_type ScaledGradedArray
112-
TensorAlgebra.@scaledarray ScaledGradedArray
113-
TensorAlgebra.@conjarray_type ConjGradedArray
114-
TensorAlgebra.@conjarray ConjGradedArray
115-
TensorAlgebra.@addarray_type AddGradedArray
116-
TensorAlgebra.@addarray AddGradedArray
117-
118-
const LazyGradedArray = Union{
119-
GradedArray, ScaledGradedArray, ConjGradedArray, AddGradedArray,
120-
}
121-
122-
function TensorAlgebra.BroadcastStyle_scaled(arrayt::Type{<:ScaledGradedArray})
123-
return BC.BroadcastStyle(TensorAlgebra.unscaled_type(arrayt))
124-
end
125-
function TensorAlgebra.BroadcastStyle_conj(arrayt::Type{<:ConjGradedArray})
126-
return BC.BroadcastStyle(TensorAlgebra.conjed_type(arrayt))
127-
end
128-
function TensorAlgebra.BroadcastStyle_add(arrayt::Type{<:AddGradedArray})
129-
args_type = TensorAlgebra.addends_type(arrayt)
130-
return Base.promote_op(BC.combine_styles, fieldtypes(args_type)...)()
131-
end
132-
133-
function lazyblock(a::ScaledGradedArray, I::Block)
134-
return TensorAlgebra.coeff(a) *lazyblock(TensorAlgebra.unscaled(a), I)
135-
end
136-
function lazyblock(a::ConjGradedArray, I::Block)
137-
return conjed(lazyblock(conjed(a), I))
138-
end
139-
function lazyblock(a::AddGradedArray, I::Block)
140-
return +(map(Base.Fix2(lazyblock, I), TensorAlgebra.addends(a))...)
141-
end
142-
143-
# TODO: Use `eachblockstoredindex` directly for lazy graded wrappers and delete the
144-
# `graded_eachblockstoredindex` helper once that refactor is split into its own PR.
145-
graded_eachblockstoredindex(a::GradedArray) = collect(eachblockstoredindex(a))
146-
function graded_eachblockstoredindex(a::ScaledGradedArray)
147-
return graded_eachblockstoredindex(TensorAlgebra.unscaled(a))
148-
end
149-
graded_eachblockstoredindex(a::ConjGradedArray) = graded_eachblockstoredindex(conjed(a))
150-
function graded_eachblockstoredindex(a::AddGradedArray)
151-
return unique!(vcat(map(graded_eachblockstoredindex, TensorAlgebra.addends(a))...))
152-
end
153-
15488
# TODO: Rename `graded_similar` to `similar_graded` or fold it into `similar`
15589
# entirely once the follow-up allocator cleanup is ready.
15690
function graded_similar(
@@ -160,52 +94,10 @@ function graded_similar(
16094
) where {N}
16195
return similar(a, elt, ax)
16296
end
163-
function graded_similar(
164-
a::ScaledGradedArray,
165-
elt::Type,
166-
ax::NTuple{N, <:GradedUnitRange}
167-
) where {N}
168-
return graded_similar(TensorAlgebra.unscaled(a), elt, ax)
169-
end
170-
function graded_similar(
171-
a::ConjGradedArray,
172-
elt::Type,
173-
ax::NTuple{N, <:GradedUnitRange}
174-
) where {N}
175-
return graded_similar(conjed(a), elt, ax)
176-
end
177-
function graded_similar(
178-
a::AddGradedArray,
179-
elt::Type,
180-
ax::NTuple{N, <:GradedUnitRange}
181-
) where {N}
182-
style = BC.combine_styles(TensorAlgebra.addends(a)...)
183-
bc = BC.Broadcasted(style, +, TensorAlgebra.addends(a))
184-
return similar(bc, elt, ax)
185-
end
186-
187-
function copy_lazygraded(a::LazyGradedArray)
188-
c = graded_similar(a, eltype(a), axes(a))
189-
for I in graded_eachblockstoredindex(a)
190-
c[I] = lazyblock(a, I)
191-
end
192-
return c
193-
end
194-
195-
function TensorAlgebra.:+(a::LazyGradedArray, b::LazyGradedArray)
196-
_check_add_axes(a, b)
197-
return AddGradedArray(a, b)
198-
end
199-
TensorAlgebra.:*::Number, a::GradedArray) = ScaledGradedArray(α, a)
200-
TensorAlgebra.conjed(a::GradedArray) = ConjGradedArray(a)
201-
202-
Base.copy(a::ScaledGradedArray) = copy_lazygraded(a)
203-
Base.copy(a::ConjGradedArray) = copy_lazygraded(a)
204-
Base.copy(a::AddGradedArray) = copy_lazygraded(a)
205-
Base.Array(a::ScaledGradedArray) = Array(copy(a))
206-
Base.Array(a::ConjGradedArray) = Array(copy(a))
207-
Base.Array(a::AddGradedArray) = Array(copy(a))
20897

209-
function BC.broadcasted(style::GradedStyle, f, args...)
210-
return TensorAlgebra.broadcasted_linear(style, f, args...)
98+
function Base.copyto!(dest::GradedArray, bc::BC.Broadcasted{<:GradedStyle})
99+
lb = TensorAlgebra.tryflattenlinear(bc)
100+
isnothing(lb) &&
101+
throw(ArgumentError("GradedArray broadcasting requires linear operations"))
102+
return copyto!(dest, lb)
211103
end

src/tensoralgebra.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,27 +155,32 @@ function TensorAlgebra.unmatricize(
155155
return SectorArray(msectors.sectors, mdata)
156156
end
157157

158-
function TensorAlgebra.permutedimsadd!(
159-
y::SectorArray, x::SectorArray, perm,
158+
function TensorAlgebra.permutedimsopadd!(
159+
y::SectorArray, op, x::SectorArray, perm,
160160
α::Number, β::Number
161161
)
162162
ysectors, ydata = kroneckerfactors(y)
163163
xsectors, xdata = kroneckerfactors(x)
164164
ysectors == permutedims(xsectors, perm) || throw(DimensionMismatch())
165165
phase = fermion_permutation_phase(xsectors, perm)
166-
TensorAlgebra.permutedimsadd!(ydata, xdata, perm, phase * α, β)
166+
TensorAlgebra.permutedimsopadd!(ydata, op, xdata, perm, phase * α, β)
167167
return y
168168
end
169-
function TensorAlgebra.permutedimsadd!(
170-
y::GradedArray{<:Any, N}, x::GradedArray{<:Any, N}, perm,
169+
function TensorAlgebra.permutedimsopadd!(
170+
y::GradedArray{<:Any, N}, op, x::GradedArray{<:Any, N}, perm,
171171
α::Number, β::Number
172172
) where {N}
173-
y .*= β
173+
if !iszero(β)
174+
for bI in eachblockstoredindex(y)
175+
block = blocks(y)[Int.(Tuple(bI))...]
176+
idperm = ntuple(identity, ndims(block))
177+
TensorAlgebra.permutedimsopadd!(block, identity, block, idperm, β, false)
178+
end
179+
end
174180
for bI in eachblockstoredindex(x)
175181
b = Tuple(bI)
176182
b_dest = ntuple(i -> b[perm[i]], N)
177-
y[Block(b_dest)] =
178-
TensorAlgebra.permutedimsadd!(y[Block(b_dest)], x[bI], perm, α, true)
183+
TensorAlgebra.permutedimsopadd!(y[Block(b_dest)], op, x[bI], perm, α, true)
179184
end
180185
return y
181186
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ SUNRepresentations = "0.3"
3636
SafeTestsets = "0.1"
3737
SparseArraysBase = "0.9"
3838
Suppressor = "0.2.8"
39-
TensorAlgebra = "0.7.19"
39+
TensorAlgebra = "0.7.21, 0.8"
4040
TensorKitSectors = "0.3"
4141
Test = "1.10"
4242
TestExtras = "0.3.1"

test/test_gradedarray.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using KroneckerArrays: cartesianrange
99
using LinearAlgebra: adjoint
1010
using Random: randn!
1111
using SparseArraysBase: storedlength
12-
using TensorAlgebra: TensorAlgebra, *ₗ, +ₗ, -ₗ, /ₗ, conjed
12+
using TensorAlgebra: TensorAlgebra, linearbroadcasted
1313
using Test: @test, @test_broken, @test_throws, @testset
1414

1515
function randn_blockdiagonal(elt::Type, axes::Tuple)
@@ -403,17 +403,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
403403
@test Array(C) α .* Array(A) .+ β .* Array(B)
404404
@test axes(C) == axes(A)
405405
@test all(dim -> isdual(axes(C, dim)) == isdual(axes(A, dim)), 1:ndims(A))
406-
Cₗ = α *ₗ A +ₗ β *ₗ B
406+
Cₗ = linearbroadcasted(+, linearbroadcasted(*, α, A), linearbroadcasted(*, β, B))
407407
@test TensorAlgebra.iscall(Cₗ)
408-
@test Array(Cₗ) α .* Array(A) .+ β .* Array(B)
408+
@test Array(copy(Cₗ)) α .* Array(A) .+ β .* Array(B)
409409
@test axes(Cₗ) == axes(A)
410410

411411
D = conj.(A) .- B ./ β
412412
@test Array(D) conj.(Array(A)) .- Array(B) ./ β
413413
@test axes(D) == axes(A)
414-
Dₗ = conjed(A) -ₗ (B /ₗ β)
414+
Dₗ = linearbroadcasted(+, linearbroadcasted(conj, A), linearbroadcasted(*, -1/β, B))
415415
@test TensorAlgebra.iscall(Dₗ)
416-
@test Array(Dₗ) conj.(Array(A)) .- Array(B) ./ β
416+
@test Array(copy(Dₗ)) conj.(Array(A)) .- Array(B) ./ β
417417
@test axes(Dₗ) == axes(A)
418418

419419
@test_throws ArgumentError A .* B

test/test_tensoralgebraext.jl

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using GradedArrays: GradedArray, GradedMatrix, SU2, SectorArray, SectorDelta, U1
44
flip, gradedrange, isdual, sector, sector_type, sectorrange, trivial,
55
trivial_gradedrange,
66
using Random: randn!
7-
using TensorAlgebra: TensorAlgebra, *ₗ, +ₗ, -ₗ, /ₗ, FusionStyle, conjed, contract,
7+
using TensorAlgebra: TensorAlgebra, FusionStyle, contract,
88
matricize, tensor_product_axis, trivial_axis, unmatricize
99
using Test: @test, @test_throws, @testset
1010

@@ -76,38 +76,20 @@ end
7676
@test st.data isa Matrix
7777
@test Array(st) α .* Array(s) .+ β .* Array(t)
7878
@test axes(st) == axes(s)
79-
@test*ₗ s) isa SectorArray
80-
@test TensorAlgebra.iscall((α *ₗ s).data)
81-
@test*ₗ s +ₗ β *ₗ t) isa SectorArray
82-
@test TensorAlgebra.iscall((α *ₗ s +ₗ β *ₗ t).data)
83-
@test Array*ₗ s +ₗ β *ₗ t) α .* Array(s) .+ β .* Array(t)
84-
@test axes*ₗ s +ₗ β *ₗ t) == axes(s)
85-
@test Base.broadcasted(*, α, s) isa SectorArray
86-
@test TensorAlgebra.iscall(Base.broadcasted(*, α, s).data)
87-
8879
conjdiff = conj.(s) .- t ./ β
8980
@test conjdiff isa SectorArray
9081
@test conjdiff.data isa Matrix
9182
@test Array(conjdiff) conj.(Array(s)) .- Array(t) ./ β
9283
@test axes(conjdiff) == axes(s)
93-
@test conjed(s) isa SectorArray
94-
@test TensorAlgebra.iscall(conjed(s).data)
95-
@test (conjed(s) -ₗ (t /ₗ β)) isa SectorArray
96-
@test TensorAlgebra.iscall((conjed(s) -ₗ (t /ₗ β)).data)
97-
@test Array(conjed(s) -ₗ (t /ₗ β)) conj.(Array(s)) .- Array(t) ./ β
98-
@test axes(conjed(s) -ₗ (t /ₗ β)) == axes(s)
9984

10085
@test_throws ArgumentError s .* t
10186
@test_throws ArgumentError exp.(s)
10287
end
10388

104-
@testset "SectorArray scalar multiplication materializes on broadcast materialize" begin
89+
@testset "SectorArray scalar multiplication materializes on broadcast" begin
10590
s = SectorArray((U1(0), dual(U1(0))), randn!(Matrix{Float64}(undef, 2, 2)))
10691

107-
scaled = Base.broadcasted(*, 2, s)
108-
@test scaled isa SectorArray
109-
@test TensorAlgebra.iscall(scaled.data)
110-
materialized = Base.Broadcast.materialize(scaled)
92+
materialized = 2 .* s
11193
@test materialized isa SectorArray
11294
@test materialized.data isa Matrix
11395
@test materialized[1, 1] == 2 * s[1, 1]
@@ -120,14 +102,6 @@ end
120102
@test Array(scaled_mul) 2 .* Array(s)
121103
end
122104

123-
@testset "SectorArray lazy display" begin
124-
s = SectorArray((U1(0), dual(U1(0))), randn!(Matrix{Float64}(undef, 2, 2)))
125-
lazy = 2 *ₗ s
126-
shown = sprint(show, MIME("text/plain"), lazy)
127-
@test contains(shown, "SectorMatrix")
128-
@test contains(shown, "")
129-
end
130-
131105
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
132106
@testset "`contract` `GradedArray` (eltype=$elt)" for elt in elts
133107
@testset "matricize" begin

0 commit comments

Comments
 (0)