Skip to content

Commit b76db01

Browse files
mtfishman and claude
authored
Migrate from TensorAlgebra macro lazy types to LinearBroadcasted (#139)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 76cc220 commit b76db01

File tree

6 files changed

+50
-173
lines changed

6 files changed

+50
-173
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
3-
version = "0.7.0"
3+
version = "0.7.1"
44
authors = ["ITensor developers <support@itensor.org> and contributors"]
55

66
[workspace]
@@ -43,12 +43,12 @@ HalfIntegers = "1.6"
4343
KroneckerArrays = "0.3.27"
4444
LinearAlgebra = "1.10"
4545
MatrixAlgebraKit = "0.6"
46-
NamedDimsArrays = "0.13, 0.14"
46+
NamedDimsArrays = "0.15"
4747
Random = "1.10"
4848
SUNRepresentations = "0.3"
4949
SparseArraysBase = "0.9"
5050
SplitApplyCombine = "1.2.3"
51-
TensorAlgebra = "0.7.20"
51+
TensorAlgebra = "0.8"
5252
TensorKitSectors = "0.3"
5353
TypeParameterAccessors = "0.4"
5454
julia = "1.10"

src/broadcast.jl

Lines changed: 15 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using Base.Broadcast: Broadcast as BC
2-
using FillArrays: Zeros, fillsimilar
3-
using TensorAlgebra: TensorAlgebra, *ₗ, +ₗ, -ₗ, /ₗ, conjed
2+
using TensorAlgebra: TensorAlgebra
43

54
struct SectorStyle{I, N} <: BC.AbstractArrayStyle{N} end
65
SectorStyle{I, N}(::Val{M}) where {I, N, M} = SectorStyle{I, M}()
@@ -31,21 +30,17 @@ function Base.Broadcast.materialize(a::SectorArray)
3130
return ofsector(a, Base.Broadcast.materialize(a.data))
3231
end
3332

34-
function TensorAlgebra.:+(a::SectorArray, b::SectorArray)
35-
_check_add_axes(a, b)
36-
return ofsector(a, a.data +ₗ b.data)
37-
end
38-
39-
function TensorAlgebra.:*(α::Number, a::SectorArray)
40-
return ofsector(a, α *ₗ a.data)
41-
end
42-
TensorAlgebra.:*(a::SectorArray, α::Number) = α *ₗ a
43-
function TensorAlgebra.conjed(a::SectorArray)
44-
return ofsector(a, TensorAlgebra.conjed(a.data))
33+
function Base.similar(bc::BC.Broadcasted{<:SectorStyle}, elt::Type, ax)
34+
bc′ = BC.flatten(bc)
35+
arg = bc′.args[findfirst(arg -> arg isa SectorArray, bc′.args)]
36+
return ofsector(arg, similar(arg.data, elt))
4537
end
4638

47-
function BC.broadcasted(style::SectorStyle, f, args...)
48-
return TensorAlgebra.broadcasted_linear(style, f, args...)
39+
function Base.copyto!(dest::SectorArray, bc::BC.Broadcasted{<:SectorStyle})
40+
lb = TensorAlgebra.tryflattenlinear(bc)
41+
isnothing(lb) &&
42+
throw(ArgumentError("SectorArray broadcasting requires linear operations"))
43+
return copyto!(dest, lb)
4944
end
5045

5146
struct GradedStyle{I, N, B <: BC.AbstractArrayStyle{N}} <: BC.AbstractArrayStyle{N}
@@ -90,67 +85,6 @@ function Base.similar(bc::BC.Broadcasted{<:GradedStyle}, elt::Type, ax)
9085
return graded_similar(arg, elt, ax)
9186
end
9287

93-
function _check_add_axes(a::AbstractArray, b::AbstractArray)
94-
axes(a) == axes(b) ||
95-
throw(
96-
ArgumentError("linear broadcasting requires matching axes")
97-
)
98-
return nothing
99-
end
100-
101-
function lazyblock(a::GradedArray{<:Any, N}, I::Vararg{Block{1}, N}) where {N}
102-
if isstored(a, I...)
103-
return blocks(a)[Int.(I)...]
104-
else
105-
block_ax = map((ax, i) -> eachblockaxis(ax)[Int(i)], axes(a), I)
106-
return fillsimilar(Zeros{eltype(a)}(block_ax), block_ax)
107-
end
108-
end
109-
lazyblock(a::GradedArray, I::Block) = lazyblock(a, Tuple(I)...)
110-
111-
TensorAlgebra.@scaledarray_type ScaledGradedArray
112-
TensorAlgebra.@scaledarray ScaledGradedArray
113-
TensorAlgebra.@conjarray_type ConjGradedArray
114-
TensorAlgebra.@conjarray ConjGradedArray
115-
TensorAlgebra.@addarray_type AddGradedArray
116-
TensorAlgebra.@addarray AddGradedArray
117-
118-
const LazyGradedArray = Union{
119-
GradedArray, ScaledGradedArray, ConjGradedArray, AddGradedArray,
120-
}
121-
122-
function TensorAlgebra.BroadcastStyle_scaled(arrayt::Type{<:ScaledGradedArray})
123-
return BC.BroadcastStyle(TensorAlgebra.unscaled_type(arrayt))
124-
end
125-
function TensorAlgebra.BroadcastStyle_conj(arrayt::Type{<:ConjGradedArray})
126-
return BC.BroadcastStyle(TensorAlgebra.conjed_type(arrayt))
127-
end
128-
function TensorAlgebra.BroadcastStyle_add(arrayt::Type{<:AddGradedArray})
129-
args_type = TensorAlgebra.addends_type(arrayt)
130-
return Base.promote_op(BC.combine_styles, fieldtypes(args_type)...)()
131-
end
132-
133-
function lazyblock(a::ScaledGradedArray, I::Block)
134-
return TensorAlgebra.coeff(a) * lazyblock(TensorAlgebra.unscaled(a), I)
135-
end
136-
function lazyblock(a::ConjGradedArray, I::Block)
137-
return conjed(lazyblock(conjed(a), I))
138-
end
139-
function lazyblock(a::AddGradedArray, I::Block)
140-
return +(map(Base.Fix2(lazyblock, I), TensorAlgebra.addends(a))...)
141-
end
142-
143-
# TODO: Use `eachblockstoredindex` directly for lazy graded wrappers and delete the
144-
# `graded_eachblockstoredindex` helper once that refactor is split into its own PR.
145-
graded_eachblockstoredindex(a::GradedArray) = collect(eachblockstoredindex(a))
146-
function graded_eachblockstoredindex(a::ScaledGradedArray)
147-
return graded_eachblockstoredindex(TensorAlgebra.unscaled(a))
148-
end
149-
graded_eachblockstoredindex(a::ConjGradedArray) = graded_eachblockstoredindex(conjed(a))
150-
function graded_eachblockstoredindex(a::AddGradedArray)
151-
return unique!(vcat(map(graded_eachblockstoredindex, TensorAlgebra.addends(a))...))
152-
end
153-
15488
# TODO: Rename `graded_similar` to `similar_graded` or fold it into `similar`
15589
# entirely once the follow-up allocator cleanup is ready.
15690
function graded_similar(
@@ -160,52 +94,10 @@ function graded_similar(
16094
) where {N}
16195
return similar(a, elt, ax)
16296
end
163-
function graded_similar(
164-
a::ScaledGradedArray,
165-
elt::Type,
166-
ax::NTuple{N, <:GradedUnitRange}
167-
) where {N}
168-
return graded_similar(TensorAlgebra.unscaled(a), elt, ax)
169-
end
170-
function graded_similar(
171-
a::ConjGradedArray,
172-
elt::Type,
173-
ax::NTuple{N, <:GradedUnitRange}
174-
) where {N}
175-
return graded_similar(conjed(a), elt, ax)
176-
end
177-
function graded_similar(
178-
a::AddGradedArray,
179-
elt::Type,
180-
ax::NTuple{N, <:GradedUnitRange}
181-
) where {N}
182-
style = BC.combine_styles(TensorAlgebra.addends(a)...)
183-
bc = BC.Broadcasted(style, +, TensorAlgebra.addends(a))
184-
return similar(bc, elt, ax)
185-
end
186-
187-
function copy_lazygraded(a::LazyGradedArray)
188-
c = graded_similar(a, eltype(a), axes(a))
189-
for I in graded_eachblockstoredindex(a)
190-
c[I] = lazyblock(a, I)
191-
end
192-
return c
193-
end
194-
195-
function TensorAlgebra.:+(a::LazyGradedArray, b::LazyGradedArray)
196-
_check_add_axes(a, b)
197-
return AddGradedArray(a, b)
198-
end
199-
TensorAlgebra.:*(α::Number, a::GradedArray) = ScaledGradedArray(α, a)
200-
TensorAlgebra.conjed(a::GradedArray) = ConjGradedArray(a)
201-
202-
Base.copy(a::ScaledGradedArray) = copy_lazygraded(a)
203-
Base.copy(a::ConjGradedArray) = copy_lazygraded(a)
204-
Base.copy(a::AddGradedArray) = copy_lazygraded(a)
205-
Base.Array(a::ScaledGradedArray) = Array(copy(a))
206-
Base.Array(a::ConjGradedArray) = Array(copy(a))
207-
Base.Array(a::AddGradedArray) = Array(copy(a))
20897

209-
function BC.broadcasted(style::GradedStyle, f, args...)
210-
return TensorAlgebra.broadcasted_linear(style, f, args...)
98+
function Base.copyto!(dest::GradedArray, bc::BC.Broadcasted{<:GradedStyle})
99+
lb = TensorAlgebra.tryflattenlinear(bc)
100+
isnothing(lb) &&
101+
throw(ArgumentError("GradedArray broadcasting requires linear operations"))
102+
return copyto!(dest, lb)
211103
end

src/tensoralgebra.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,27 +155,34 @@ function TensorAlgebra.unmatricize(
155155
return SectorArray(msectors.sectors, mdata)
156156
end
157157

158-
function TensorAlgebra.permutedimsadd!(
159-
y::SectorArray, x::SectorArray, perm,
158+
function TensorAlgebra.permutedimsopadd!(
159+
y::SectorArray, op, x::SectorArray, perm,
160160
α::Number, β::Number
161161
)
162162
ysectors, ydata = kroneckerfactors(y)
163163
xsectors, xdata = kroneckerfactors(x)
164164
ysectors == permutedims(xsectors, perm) || throw(DimensionMismatch())
165165
phase = fermion_permutation_phase(xsectors, perm)
166-
TensorAlgebra.permutedimsadd!(ydata, xdata, perm, phase * α, β)
166+
TensorAlgebra.permutedimsopadd!(ydata, op, xdata, perm, phase * α, β)
167167
return y
168168
end
169-
function TensorAlgebra.permutedimsadd!(
170-
y::GradedArray{<:Any, N}, x::GradedArray{<:Any, N}, perm,
169+
function TensorAlgebra.permutedimsopadd!(
170+
y::GradedArray{<:Any, N}, op, x::GradedArray{<:Any, N}, perm,
171171
α::Number, β::Number
172172
) where {N}
173-
y .*= β
173+
if !iszero(β)
174+
for bI in eachblockstoredindex(y)
175+
y_b = @view!(y[bI])
176+
idperm = ntuple(identity, ndims(y_b))
177+
TensorAlgebra.permutedimsopadd!(y_b, identity, y_b, idperm, β, false)
178+
end
179+
end
174180
for bI in eachblockstoredindex(x)
175181
b = Tuple(bI)
176-
b_dest = ntuple(i -> b[perm[i]], N)
177-
y[Block(b_dest)] =
178-
TensorAlgebra.permutedimsadd!(y[Block(b_dest)], x[bI], perm, α, true)
182+
b_dest = Block(ntuple(i -> b[perm[i]], N))
183+
y_b = @view!(y[b_dest])
184+
x_b = @view!(x[bI])
185+
TensorAlgebra.permutedimsopadd!(y_b, op, x_b, perm, α, true)
179186
end
180187
return y
181188
end

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ ITensorPkgSkeleton = "0.3.42"
3030
KroneckerArrays = "0.3"
3131
LinearAlgebra = "1.10"
3232
MatrixAlgebraKit = "0.6"
33-
NamedDimsArrays = "0.13, 0.14"
33+
NamedDimsArrays = "0.15"
3434
Random = "1.10"
3535
SUNRepresentations = "0.3"
3636
SafeTestsets = "0.1"
3737
SparseArraysBase = "0.9"
3838
Suppressor = "0.2.8"
39-
TensorAlgebra = "0.7.19"
39+
TensorAlgebra = "0.8"
4040
TensorKitSectors = "0.3"
4141
Test = "1.10"
4242
TestExtras = "0.3.1"

test/test_gradedarray.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using KroneckerArrays: cartesianrange
99
using LinearAlgebra: adjoint
1010
using Random: randn!
1111
using SparseArraysBase: storedlength
12-
using TensorAlgebra: TensorAlgebra, *ₗ, +ₗ, -ₗ, /ₗ, conjed
12+
using TensorAlgebra: TensorAlgebra, linearbroadcasted
1313
using Test: @test, @test_broken, @test_throws, @testset
1414

1515
function randn_blockdiagonal(elt::Type, axes::Tuple)
@@ -403,24 +403,28 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
403403
@test Array(C) ≈ α .* Array(A) .+ β .* Array(B)
404404
@test axes(C) == axes(A)
405405
@test all(dim -> isdual(axes(C, dim)) == isdual(axes(A, dim)), 1:ndims(A))
406-
Cₗ = α *ₗ A +ₗ β *ₗ B
406+
Cₗ = linearbroadcasted(+, linearbroadcasted(*, α, A), linearbroadcasted(*, β, B))
407407
@test TensorAlgebra.iscall(Cₗ)
408-
@test Array(Cₗ) ≈ α .* Array(A) .+ β .* Array(B)
408+
@test Array(copy(Cₗ)) ≈ α .* Array(A) .+ β .* Array(B)
409409
@test axes(Cₗ) == axes(A)
410410

411411
D = conj.(A) .- B ./ β
412412
@test Array(D) ≈ conj.(Array(A)) .- Array(B) ./ β
413413
@test axes(D) == axes(A)
414-
Dₗ = conjed(A) -ₗ (B /ₗ β)
414+
Dₗ = linearbroadcasted(
415+
+,
416+
linearbroadcasted(conj, A),
417+
linearbroadcasted(*, -1 / β, B)
418+
)
415419
@test TensorAlgebra.iscall(Dₗ)
416-
@test Array(Dₗ) ≈ conj.(Array(A)) .- Array(B) ./ β
420+
@test Array(copy(Dₗ)) ≈ conj.(Array(A)) .- Array(B) ./ β
417421
@test axes(Dₗ) == axes(A)
418422

419423
@test_throws ArgumentError A .* B
420424

421425
r_bad = gradedrange([U1(0) => 1, U1(1) => 3])
422426
B_bad = randn_blockdiagonal(elt, (r_bad, dual(r_bad)))
423-
@test_throws ArgumentError A .+ B_bad
427+
@test_throws DimensionMismatch A .+ B_bad
424428
end
425429
false && @testset "Construct from dense" begin
426430
r = gradedrange([U1(0) => 2, U1(1) => 3])

test/test_tensoralgebraext.jl

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ using GradedArrays: GradedArray, GradedMatrix, SU2, SectorArray, SectorDelta, U1
44
flip, gradedrange, isdual, sector, sector_type, sectorrange, trivial,
55
trivial_gradedrange,
66
using Random: randn!
7-
using TensorAlgebra: TensorAlgebra, *ₗ, +ₗ, -ₗ, /ₗ, FusionStyle, conjed, contract,
8-
matricize, tensor_product_axis, trivial_axis, unmatricize
7+
using TensorAlgebra: TensorAlgebra, FusionStyle, contract, matricize, tensor_product_axis,
8+
trivial_axis, unmatricize
99
using Test: @test, @test_throws, @testset
1010

1111
function randn_blockdiagonal(elt::Type, axes::Tuple)
@@ -76,38 +76,20 @@ end
7676
@test st.data isa Matrix
7777
@test Array(st) ≈ α .* Array(s) .+ β .* Array(t)
7878
@test axes(st) == axes(s)
79-
@test*ₗ s) isa SectorArray
80-
@test TensorAlgebra.iscall((α *ₗ s).data)
81-
@test*ₗ s +ₗ β *ₗ t) isa SectorArray
82-
@test TensorAlgebra.iscall((α *ₗ s +ₗ β *ₗ t).data)
83-
@test Array*ₗ s +ₗ β *ₗ t) α .* Array(s) .+ β .* Array(t)
84-
@test axes*ₗ s +ₗ β *ₗ t) == axes(s)
85-
@test Base.broadcasted(*, α, s) isa SectorArray
86-
@test TensorAlgebra.iscall(Base.broadcasted(*, α, s).data)
87-
8879
conjdiff = conj.(s) .- t ./ β
8980
@test conjdiff isa SectorArray
9081
@test conjdiff.data isa Matrix
9182
@test Array(conjdiff) ≈ conj.(Array(s)) .- Array(t) ./ β
9283
@test axes(conjdiff) == axes(s)
93-
@test conjed(s) isa SectorArray
94-
@test TensorAlgebra.iscall(conjed(s).data)
95-
@test (conjed(s) -ₗ (t /ₗ β)) isa SectorArray
96-
@test TensorAlgebra.iscall((conjed(s) -ₗ (t /ₗ β)).data)
97-
@test Array(conjed(s) -ₗ (t /ₗ β)) ≈ conj.(Array(s)) .- Array(t) ./ β
98-
@test axes(conjed(s) -ₗ (t /ₗ β)) == axes(s)
9984

10085
@test_throws ArgumentError s .* t
10186
@test_throws ArgumentError exp.(s)
10287
end
10388

104-
@testset "SectorArray scalar multiplication materializes on broadcast materialize" begin
89+
@testset "SectorArray scalar multiplication materializes on broadcast" begin
10590
s = SectorArray((U1(0), dual(U1(0))), randn!(Matrix{Float64}(undef, 2, 2)))
10691

107-
scaled = Base.broadcasted(*, 2, s)
108-
@test scaled isa SectorArray
109-
@test TensorAlgebra.iscall(scaled.data)
110-
materialized = Base.Broadcast.materialize(scaled)
92+
materialized = 2 .* s
11193
@test materialized isa SectorArray
11294
@test materialized.data isa Matrix
11395
@test materialized[1, 1] == 2 * s[1, 1]
@@ -120,14 +102,6 @@ end
120102
@test Array(scaled_mul) ≈ 2 .* Array(s)
121103
end
122104

123-
@testset "SectorArray lazy display" begin
124-
s = SectorArray((U1(0), dual(U1(0))), randn!(Matrix{Float64}(undef, 2, 2)))
125-
lazy = 2 *ₗ s
126-
shown = sprint(show, MIME("text/plain"), lazy)
127-
@test contains(shown, "SectorMatrix")
128-
@test contains(shown, "")
129-
end
130-
131105
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
132106
@testset "`contract` `GradedArray` (eltype=$elt)" for elt in elts
133107
@testset "matricize" begin

0 commit comments

Comments
 (0)