Skip to content

Commit c345a71

Browse files
authored
add dedicated constructors (#59)
1 parent 0244131 commit c345a71

File tree

11 files changed

+125
-34
lines changed

11 files changed

+125
-34
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FusionTensors"
22
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.1"
4+
version = "0.5.0"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -11,6 +11,7 @@ GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
1111
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
1212
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1415
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
1516
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1617
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
@@ -25,6 +26,7 @@ GradedArrays = "0.4.13"
2526
HalfIntegers = "1.6"
2627
LRUCache = "1.6"
2728
LinearAlgebra = "1.10"
29+
Random = "1.10"
2830
Strided = "2.3"
2931
TensorAlgebra = "0.3.8"
3032
TensorProducts = "0.1.7"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
55

66
[compat]
77
Documenter = "1.10.0"
8-
FusionTensors = "0.4"
8+
FusionTensors = "0.5"
99
Literate = "2.20.1"

src/fusiontensor/array_cast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function to_fusiontensor_no_checknorm(
6464
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
6565
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
6666
)
67-
ft = FusionTensor(eltype(blockarray), codomain_legs, domain_legs)
67+
ft = FusionTensor{eltype(blockarray)}(undef, codomain_legs, domain_legs)
6868
for (f1, f2) in keys(trees_block_mapping(ft))
6969
b = findblock(ft, f1, f2)
7070
ft[f1, f2] = contract_fusion_trees(blockarray[b], f1, f2)

src/fusiontensor/base_interface.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# This files defines Base functions for FusionTensor
22

33
using Accessors: @set
4-
5-
using BlockSparseArrays: @view!
4+
using BlockSparseArrays: @view!, eachstoredblock
65
using TensorAlgebra: BlockedTuple, tuplemortar
76

87
set_data_matrix(ft::FusionTensor, data_matrix) = @set ft.data_matrix = data_matrix
@@ -105,7 +104,7 @@ function Base.similar(
105104
return similar(ft, T, tuplemortar(new_axes))
106105
end
107106
function Base.similar(::FusionTensor, ::Type{T}, new_axes::BlockedTuple{2}) where {T}
108-
return FusionTensor(T, new_axes)
107+
return FusionTensor{T}(undef, new_axes)
109108
end
110109

111110
Base.show(io::IO, ft::FusionTensor) = print(io, "$(ndims(ft))-dim FusionTensor")

src/fusiontensor/fusiontensor.jl

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ using GradedArrays:
1919
sectormergesort,
2020
sectors,
2121
space_isequal
22+
using LinearAlgebra: UniformScaling
23+
using Random: Random, AbstractRNG, randn!
2224
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
2325
using TensorProducts: tensor_product
2426
using TypeParameterAccessors: type_parameters
@@ -174,12 +176,12 @@ function FusionTensor(
174176
end
175177

176178
# empty matrix
177-
function FusionTensor(elt::Type, legs::FusionTensorAxes)
179+
function FusionTensor{T}(::UndefInitializer, legs::FusionTensorAxes) where {T}
178180
S = sector_type(legs)
179181
row_axis, codomain_trees_to_ranges = fuse_axes(S, codomain(legs))
180182
col_axis, domain_trees_to_ranges = flip_domain(fuse_axes(S, dual.(domain(legs)))...)
181183

182-
mat = initialize_data_matrix(elt, row_axis, col_axis)
184+
mat = initialize_data_matrix(T, row_axis, col_axis)
183185
tree_to_block_mapping = intersect_codomain_domain(
184186
codomain_trees_to_ranges, domain_trees_to_ranges
185187
)
@@ -189,7 +191,7 @@ end
189191
#constructor from precomputed data_matrix
190192
function FusionTensor(mat::AbstractMatrix, legs::FusionTensorAxes)
191193
# init with empty data_matrix to construct trees_block_mapping
192-
ft = FusionTensor(eltype(mat), legs)
194+
ft = FusionTensor{eltype(mat)}(undef, legs)
193195
for b in eachblockstoredindex(mat)
194196
b in eachblockstoredindex(data_matrix(ft)) ||
195197
throw(ArgumentError("matrix block $b is not allowed"))
@@ -199,6 +201,9 @@ function FusionTensor(mat::AbstractMatrix, legs::FusionTensorAxes)
199201
end
200202

201203
FusionTensor(x, legs::BlockedTuple{2}) = FusionTensor(x, FusionTensorAxes(legs))
204+
function FusionTensor{T}(x, legs::BlockedTuple{2}) where {T}
205+
return FusionTensor{T}(x, FusionTensorAxes(legs))
206+
end
202207

203208
# constructor from split axes
204209
function FusionTensor(
@@ -209,6 +214,47 @@ function FusionTensor(
209214
return FusionTensor(x, tuplemortar((codomain_legs, domain_legs)))
210215
end
211216

217+
function FusionTensor{T}(
218+
x,
219+
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
220+
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
221+
) where {T}
222+
return FusionTensor{T}(x, tuplemortar((codomain_legs, domain_legs)))
223+
end
224+
225+
# specific constructors
226+
function Base.zeros(::Type{T}, fta::FusionTensorAxes) where {T}
227+
ft = FusionTensor{T}(undef, fta)
228+
foreach(m -> fill!(m, zero(T)), eachstoredblock(data_matrix(ft)))
229+
return ft
230+
end
231+
Base.zeros(fta::FusionTensorAxes) = zeros(Float64, fta)
232+
233+
function Base.randn(rng::AbstractRNG, ::Type{T}, fta::FusionTensorAxes) where {T}
234+
ft = FusionTensor{T}(undef, fta)
235+
foreach(m -> randn!(rng, m), eachstoredblock(data_matrix(ft)))
236+
return ft
237+
end
238+
Base.randn(rng::AbstractRNG, fta::FusionTensorAxes) = randn(rng, Float64, fta)
239+
Base.randn(::Type{T}, fta::FusionTensorAxes) where {T} = randn(Random.default_rng(), T, fta)
240+
Base.randn(fta::FusionTensorAxes) = randn(Float64, fta)
241+
242+
function FusionTensor{T}(
243+
s::UniformScaling, codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}}
244+
) where {T}
245+
fta = FusionTensorAxes(codomain_legs, dual.(codomain_legs))
246+
ft = FusionTensor{T}(undef, fta)
247+
for m in eachstoredblock(data_matrix(ft))
248+
m .= s(size(m, 1))
249+
end
250+
return ft
251+
end
252+
function FusionTensor(
253+
s::UniformScaling, codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}}
254+
)
255+
return FusionTensor{Float64}(s, codomain_legs)
256+
end
257+
212258
# ================================ BlockArrays interface =================================
213259

214260
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)

src/permutedims/permutedims.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,19 @@ function fusiontensor_permutedims(ft, biperm::BlockedPermutation{2})
3838
end
3939
end
4040

41-
new_ft = FusionTensor(eltype(ft), axes(ft)[biperm])
41+
new_ft = FusionTensor{eltype(ft)}(undef, axes(ft)[biperm])
4242
fusiontensor_permutedims!(new_ft, ft, Tuple(biperm))
4343
return new_ft
4444
end
4545

4646
function fusiontensor_permutedims!(
4747
new_ft::FusionTensor{T,N}, old_ft::FusionTensor{T,N}, flatperm::NTuple{N,Integer}
4848
) where {T,N}
49+
foreach(m -> fill!(m, zero(T)), eachstoredblock(data_matrix(new_ft)))
4950
unitary = compute_unitary(new_ft, old_ft, flatperm)
50-
for p in unitary
51-
old_trees, new_trees = first(p)
51+
for ((old_trees, new_trees), coeff) in unitary
5252
new_block = view(new_ft, new_trees...)
5353
old_block = view(old_ft, old_trees...)
54-
@strided new_block .+= last(p) .* permutedims(old_block, flatperm)
54+
@strided new_block .+= coeff .* permutedims(old_block, flatperm)
5555
end
5656
end

test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
55
FusionTensors = "e16ca583-1f51-4df0-8e12-57d32947d33e"
66
GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
910
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1011
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
@@ -15,9 +16,10 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1516
Aqua = "0.8.11"
1617
BlockArrays = "1.6"
1718
BlockSparseArrays = "0.7"
18-
FusionTensors = "0.4"
19+
FusionTensors = "0.5"
1920
GradedArrays = "0.4"
2021
LinearAlgebra = "1.10.0"
22+
Random = "1.10"
2123
SafeTestsets = "0.1.0"
2224
Suppressor = "0.2.8"
2325
TensorAlgebra = "0.3"

test/test_basics.jl

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Test: @test, @test_throws, @testset
22

33
using BlockArrays: Block
4-
using BlockSparseArrays: BlockSparseArray
4+
using BlockSparseArrays: BlockSparseArray, eachblockstoredindex
55
using FusionTensors:
66
FusionTensor,
77
FusionTensorAxes,
@@ -28,15 +28,23 @@ using GradedArrays:
2828
space_isequal
2929
using TensorAlgebra: tuplemortar
3030
using TensorProducts: tensor_product
31+
using LinearAlgebra: LinearAlgebra
32+
using Random: Random
3133

3234
include("setup.jl")
3335

3436
@testset "Fusion matrix" begin
3537
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
3638
g2 = dual(gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1]))
3739

40+
fta = FusionTensorAxes((g1,), (g2,))
41+
ft0 = FusionTensor{Float64}(undef, fta)
42+
@test ft0 isa FusionTensor
43+
@test space_isequal(codomain_axis(ft0), g1)
44+
@test space_isequal(domain_axis(ft0), g2)
45+
3846
# check dual convention when initializing data_matrix
39-
ft0 = FusionTensor(Float64, (g1,), (g2,))
47+
ft0 = FusionTensor{Float64}(undef, (g1,), (g2,))
4048
@test ft0 isa FusionTensor
4149
@test space_isequal(codomain_axis(ft0), g1)
4250
@test space_isequal(domain_axis(ft0), g2)
@@ -146,7 +154,7 @@ end
146154
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
147155

148156
# one row axis
149-
ft1 = FusionTensor(Float64, (g1,), ())
157+
ft1 = FusionTensor{Float64}(undef, (g1,), ())
150158
@test ndims_codomain(ft1) == 1
151159
@test ndims_domain(ft1) == 0
152160
@test ndims(ft1) == 1
@@ -156,7 +164,7 @@ end
156164
@test sector_type(ft1) === sector_type(g1)
157165

158166
# one column axis
159-
ft2 = FusionTensor(Float64, (), (g1,))
167+
ft2 = FusionTensor{Float64}(undef, (), (g1,))
160168
@test ndims_codomain(ft2) == 0
161169
@test ndims_domain(ft2) == 1
162170
@test ndims(ft2) == 1
@@ -166,7 +174,7 @@ end
166174
@test sector_type(ft2) === sector_type(g1)
167175

168176
# zero axis
169-
ft3 = FusionTensor(Float64, (), ())
177+
ft3 = FusionTensor{Float64}(undef, (), ())
170178
@test ndims_codomain(ft3) == 0
171179
@test ndims_domain(ft3) == 0
172180
@test ndims(ft3) == 0
@@ -181,7 +189,7 @@ end
181189
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
182190
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
183191
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])
184-
ft3 = FusionTensor(Float64, (g1, g2), (g3, g4))
192+
ft3 = FusionTensor{Float64}(undef, (g1, g2), (g3, g4))
185193
@test isnothing(check_sanity(ft3))
186194

187195
ft4 = +ft3
@@ -260,19 +268,58 @@ end
260268
@test space_isequal(dual(g4), codomain_axes(ad)[2])
261269
@test isnothing(check_sanity(ad))
262270

263-
ft7 = FusionTensor(Float64, (g1,), (g2, g3, g4))
271+
ft7 = FusionTensor{Float64}(undef, (g1,), (g2, g3, g4))
264272
@test_throws DimensionMismatch ft7 + ft3
265273
@test_throws DimensionMismatch ft7 - ft3
266274
@test_throws DimensionMismatch ft7 * ft3
267275
end
268276

277+
@testset "specific constructors" begin
278+
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
279+
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
280+
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
281+
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])
282+
283+
fta = FusionTensorAxes((g1,), (g2, g3))
284+
@test zeros(fta) isa FusionTensor{Float64,3}
285+
@test zeros(ComplexF64, fta) isa FusionTensor{ComplexF64,3}
286+
287+
rng = Random.default_rng()
288+
ft1 = randn(rng, ComplexF64, fta)
289+
@test ft1 isa FusionTensor{ComplexF64,3}
290+
@test all(!=(0), data_matrix(ft1)[Block(1, 5)])
291+
@test randn(rng, fta) isa FusionTensor{Float64,3}
292+
@test randn(ComplexF64, fta) isa FusionTensor{ComplexF64,3}
293+
@test randn(fta) isa FusionTensor{Float64,3}
294+
295+
ft2 = FusionTensor(LinearAlgebra.I, (g1, g2))
296+
@test ft2 isa FusionTensor{Float64,4}
297+
@test axes(ft2) == FusionTensorAxes((g1, g2), dual.((g1, g2)))
298+
@test collect(eachblockstoredindex(data_matrix(ft2))) == map(i -> Block(i, i), 1:6)
299+
for i in 1:6
300+
m = data_matrix(ft2)[Block(i, i)]
301+
@test m == LinearAlgebra.I(size(m, 1))
302+
end
303+
304+
ft2 = FusionTensor(3 * LinearAlgebra.I, (g1, g2))
305+
@test ft2 isa FusionTensor{Float64,4}
306+
@test axes(ft2) == FusionTensorAxes((g1, g2), dual.((g1, g2)))
307+
@test collect(eachblockstoredindex(data_matrix(ft2))) == map(i -> Block(i, i), 1:6)
308+
for i in 1:6
309+
m = data_matrix(ft2)[Block(i, i)]
310+
@test m == 3 * LinearAlgebra.I(size(m, 1))
311+
end
312+
313+
@test FusionTensor{ComplexF64}(LinearAlgebra.I, (g1, g2)) isa FusionTensor{ComplexF64,4}
314+
end
315+
269316
@testset "missing SectorProduct" begin
270317
g1 = gradedrange([SectorProduct(U1(1)) => 1])
271318
g2 = gradedrange([SectorProduct(U1(1), SU2(1//2)) => 1])
272319
g3 = gradedrange([SectorProduct(U1(1), SU2(1//2), Z{2}(1)) => 1])
273320
S = sector_type(g3)
274321

275-
ft = FusionTensor(Float64, (g1,), (dual(g2), dual(g3)))
322+
ft = FusionTensor{Float64}(undef, (g1,), (dual(g2), dual(g3)))
276323
@test sector_type(ft) === S
277324
gr = gradedrange([SectorProduct(U1(1), SU2(0), Z{2}(0)) => 1])
278325
@test space_isequal(codomain_axis(ft), gr)
@@ -287,7 +334,7 @@ end
287334
gABC = tensor_product(gA, gB, gC)
288335
S = sector_type(gABC)
289336

290-
ft = FusionTensor(Float64, (gA, gB), (dual(gA), dual(gB), gC))
337+
ft = FusionTensor{Float64}(undef, (gA, gB), (dual(gA), dual(gB), gC))
291338
@test sector_type(ft) === S
292339
@test space_isequal(codomain_axis(ft), gABC)
293340
@test space_isequal(domain_axis(ft), dual(gABC))

test/test_contraction.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ include("setup.jl")
1414
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
1515
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])
1616

17-
ft1 = FusionTensor(Float64, (g1, g2), (g3, g4))
17+
ft1 = FusionTensor{Float64}(undef, (g1, g2), (g3, g4))
1818
@test isnothing(check_sanity(ft1))
1919

20-
ft2 = FusionTensor(Float64, dual.((g3, g4)), (g1,))
20+
ft2 = FusionTensor{Float64}(undef, dual.((g3, g4)), (g1,))
2121
@test isnothing(check_sanity(ft2))
2222

2323
ft3 = ft1 * ft2 # tensor contraction
@@ -43,9 +43,9 @@ end
4343
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
4444
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])
4545

46-
ft1 = FusionTensor(Float64, (g1, g2), (g3, g4))
47-
ft2 = FusionTensor(Float64, dual.((g3, g4)), (dual(g1),))
48-
ft3 = FusionTensor(Float64, dual.((g3, g4)), dual.((g1, g2)))
46+
ft1 = FusionTensor{Float64}(undef, (g1, g2), (g3, g4))
47+
ft2 = FusionTensor{Float64}(undef, dual.((g3, g4)), (dual(g1),))
48+
ft3 = FusionTensor{Float64}(undef, dual.((g3, g4)), dual.((g1, g2)))
4949

5050
ft4, legs = contract(ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
5151
@test legs == tuplemortar(((1, 2), (5,)))

test/test_linear_algebra.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,6 @@ include("setup.jl")
2323
gsu2 = gradedrange([SU2(1 / 2) => 1])
2424

2525
for g in [g0, gu1, gsu2]
26-
ft0 = FusionTensor(Float64, (g, g), (dual(g), dual(g)))
27-
@test isnothing(check_sanity(ft0))
28-
@test norm(ft0) == 0
29-
@test tr(ft0) == 0
30-
3126
ft = to_fusiontensor(sdst, (g, g), (dual(g), dual(g)))
3227
@test isnothing(check_sanity(ft))
3328
@test norm(ft) 3 / 2

0 commit comments

Comments
 (0)