Skip to content

Commit 76cbf2c

Browse files
committed
undef init
1 parent f2eee46 commit 76cbf2c

File tree

5 files changed

+37
-21
lines changed

5 files changed

+37
-21
lines changed

src/fusiontensor/array_cast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function to_fusiontensor_no_checknorm(
6464
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
6565
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
6666
)
67-
ft = FusionTensor(eltype(blockarray), codomain_legs, domain_legs)
67+
ft = FusionTensor{eltype(blockarray)}(undef, codomain_legs, domain_legs)
6868
for (f1, f2) in keys(trees_block_mapping(ft))
6969
b = findblock(ft, f1, f2)
7070
ft[f1, f2] = contract_fusion_trees(blockarray[b], f1, f2)

src/fusiontensor/base_interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ function Base.similar(
104104
return similar(ft, T, tuplemortar(new_axes))
105105
end
106106
function Base.similar(::FusionTensor, ::Type{T}, new_axes::BlockedTuple{2}) where {T}
107-
return FusionTensor(T, new_axes)
107+
return FusionTensor{T}(undef, new_axes)
108108
end
109109

110110
Base.show(io::IO, ft::FusionTensor) = print(io, "$(ndims(ft))-dim FusionTensor")

src/fusiontensor/fusiontensor.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,12 @@ function FusionTensor(
176176
end
177177

178178
# empty matrix
179-
function FusionTensor(elt::Type, legs::FusionTensorAxes)
179+
function FusionTensor{T}(::UndefInitializer, legs::FusionTensorAxes) where {T}
180180
S = sector_type(legs)
181181
row_axis, codomain_trees_to_ranges = fuse_axes(S, codomain(legs))
182182
col_axis, domain_trees_to_ranges = flip_domain(fuse_axes(S, dual.(domain(legs)))...)
183183

184-
mat = initialize_data_matrix(elt, row_axis, col_axis)
184+
mat = initialize_data_matrix(T, row_axis, col_axis)
185185
tree_to_block_mapping = intersect_codomain_domain(
186186
codomain_trees_to_ranges, domain_trees_to_ranges
187187
)
@@ -191,7 +191,7 @@ end
191191
#constructor from precomputed data_matrix
192192
function FusionTensor(mat::AbstractMatrix, legs::FusionTensorAxes)
193193
# init with empty data_matrix to construct trees_block_mapping
194-
ft = FusionTensor(eltype(mat), legs)
194+
ft = FusionTensor{eltype(mat)}(undef, legs)
195195
for b in eachblockstoredindex(mat)
196196
b in eachblockstoredindex(data_matrix(ft)) ||
197197
throw(ArgumentError("matrix block $b is not allowed"))
@@ -201,6 +201,9 @@ function FusionTensor(mat::AbstractMatrix, legs::FusionTensorAxes)
201201
end
202202

203203
FusionTensor(x, legs::BlockedTuple{2}) = FusionTensor(x, FusionTensorAxes(legs))
204+
function FusionTensor{T}(x, legs::BlockedTuple{2}) where {T}
205+
return FusionTensor{T}(x, FusionTensorAxes(legs))
206+
end
204207

205208
# constructor from split axes
206209
function FusionTensor(
@@ -211,15 +214,25 @@ function FusionTensor(
211214
return FusionTensor(x, tuplemortar((codomain_legs, domain_legs)))
212215
end
213216

217+
function FusionTensor{T}(
218+
x,
219+
codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
220+
domain_legs::Tuple{Vararg{AbstractGradedUnitRange}},
221+
) where {T}
222+
return FusionTensor{T}(x, tuplemortar((codomain_legs, domain_legs)))
223+
end
224+
214225
# specific constructors
215-
Base.zeros(::Type{T}, fta::FusionTensorAxes) where {T} = FusionTensor(T, fta)
226+
function Base.zeros(::Type{T}, fta::FusionTensorAxes) where {T}
227+
ft = FusionTensor{T}(undef, fta)
228+
map(m -> fill!(m, zero(T)), eachstoredblock(data_matrix(ft)))
229+
return ft
230+
end
216231
Base.zeros(fta::FusionTensorAxes) = zeros(Float64, fta)
217232

218233
function Base.randn(rng::AbstractRNG, ::Type{T}, fta::FusionTensorAxes) where {T}
219-
ft = FusionTensor(T, fta)
220-
for m in eachstoredblock(data_matrix(ft))
221-
randn!(rng, m)
222-
end
234+
ft = FusionTensor{T}(undef, fta)
235+
map(m -> randn!(rng, m), eachstoredblock(data_matrix(ft)))
223236
return ft
224237
end
225238
Base.randn(rng::AbstractRNG, fta::FusionTensorAxes) = randn(rng, Float64, fta)
@@ -230,7 +243,7 @@ function FusionTensor{T}(
230243
s::UniformScaling, codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}}
231244
) where {T}
232245
fta = FusionTensorAxes(codomain_legs, dual.(codomain_legs))
233-
ft = FusionTensor(T, fta)
246+
ft = FusionTensor{T}(undef, fta)
234247
for m in eachstoredblock(data_matrix(ft))
235248
m .= s(size(m, 1))
236249
end

src/permutedims/permutedims.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,15 @@ function fusiontensor_permutedims(ft, biperm::BlockedPermutation{2})
3838
end
3939
end
4040

41-
new_ft = FusionTensor(eltype(ft), axes(ft)[biperm])
41+
new_ft = FusionTensor{eltype(ft)}(undef, axes(ft)[biperm])
4242
fusiontensor_permutedims!(new_ft, ft, Tuple(biperm))
4343
return new_ft
4444
end
4545

4646
function fusiontensor_permutedims!(
4747
new_ft::FusionTensor{T,N}, old_ft::FusionTensor{T,N}, flatperm::NTuple{N,Integer}
4848
) where {T,N}
49+
map(m -> fill!(m, zero(T)), eachstoredblock(data_matrix(new_ft)))
4950
unitary = compute_unitary(new_ft, old_ft, flatperm)
5051
for p in unitary
5152
old_trees, new_trees = first(p)

test/test_basics.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@ include("setup.jl")
3838
g2 = dual(gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1]))
3939

4040
fta = FusionTensorAxes((g1,), (g2,))
41-
ft0 = FusionTensor(Float64, fta)
41+
ft0 = FusionTensor{Float64}(undef, fta)
4242
@test ft0 isa FusionTensor
4343
@test space_isequal(codomain_axis(ft0), g1)
4444
@test space_isequal(domain_axis(ft0), g2)
4545

4646
# check dual convention when initializing data_matrix
47-
ft0 = FusionTensor(Float64, (g1,), (g2,))
47+
ft0 = FusionTensor{Float64}(undef, (g1,), (g2,))
4848
@test ft0 isa FusionTensor
4949
@test space_isequal(codomain_axis(ft0), g1)
5050
@test space_isequal(domain_axis(ft0), g2)
@@ -154,7 +154,7 @@ end
154154
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
155155

156156
# one row axis
157-
ft1 = FusionTensor(Float64, (g1,), ())
157+
ft1 = FusionTensor{Float64}(undef, (g1,), ())
158158
@test ndims_codomain(ft1) == 1
159159
@test ndims_domain(ft1) == 0
160160
@test ndims(ft1) == 1
@@ -164,7 +164,7 @@ end
164164
@test sector_type(ft1) === sector_type(g1)
165165

166166
# one column axis
167-
ft2 = FusionTensor(Float64, (), (g1,))
167+
ft2 = FusionTensor{Float64}(undef, (), (g1,))
168168
@test ndims_codomain(ft2) == 0
169169
@test ndims_domain(ft2) == 1
170170
@test ndims(ft2) == 1
@@ -174,7 +174,7 @@ end
174174
@test sector_type(ft2) === sector_type(g1)
175175

176176
# zero axis
177-
ft3 = FusionTensor(Float64, (), ())
177+
ft3 = FusionTensor{Float64}(undef, (), ())
178178
@test ndims_codomain(ft3) == 0
179179
@test ndims_domain(ft3) == 0
180180
@test ndims(ft3) == 0
@@ -189,7 +189,7 @@ end
189189
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
190190
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
191191
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])
192-
ft3 = FusionTensor(Float64, (g1, g2), (g3, g4))
192+
ft3 = FusionTensor{Float64}(undef, (g1, g2), (g3, g4))
193193
@test isnothing(check_sanity(ft3))
194194

195195
ft4 = +ft3
@@ -268,7 +268,7 @@ end
268268
@test space_isequal(dual(g4), codomain_axes(ad)[2])
269269
@test isnothing(check_sanity(ad))
270270

271-
ft7 = FusionTensor(Float64, (g1,), (g2, g3, g4))
271+
ft7 = FusionTensor{Float64}(undef, (g1,), (g2, g3, g4))
272272
@test_throws DimensionMismatch ft7 + ft3
273273
@test_throws DimensionMismatch ft7 - ft3
274274
@test_throws DimensionMismatch ft7 * ft3
@@ -309,6 +309,8 @@ end
309309
m = data_matrix(ft2)[Block(i, i)]
310310
@test m == 3 * LinearAlgebra.I(size(m, 1))
311311
end
312+
313+
@test FusionTensor{ComplexF64}(LinearAlgebra.I, (g1, g2)) isa FusionTensor{ComplexF64,4}
312314
end
313315

314316
@testset "missing SectorProduct" begin
@@ -317,7 +319,7 @@ end
317319
g3 = gradedrange([SectorProduct(U1(1), SU2(1//2), Z{2}(1)) => 1])
318320
S = sector_type(g3)
319321

320-
ft = FusionTensor(Float64, (g1,), (dual(g2), dual(g3)))
322+
ft = FusionTensor{Float64}(undef, (g1,), (dual(g2), dual(g3)))
321323
@test sector_type(ft) === S
322324
gr = gradedrange([SectorProduct(U1(1), SU2(0), Z{2}(0)) => 1])
323325
@test space_isequal(codomain_axis(ft), gr)
@@ -332,7 +334,7 @@ end
332334
gABC = tensor_product(gA, gB, gC)
333335
S = sector_type(gABC)
334336

335-
ft = FusionTensor(Float64, (gA, gB), (dual(gA), dual(gB), gC))
337+
ft = FusionTensor{Float64}(undef, (gA, gB), (dual(gA), dual(gB), gC))
336338
@test sector_type(ft) === S
337339
@test space_isequal(codomain_axis(ft), gABC)
338340
@test space_isequal(domain_axis(ft), dual(gABC))

0 commit comments

Comments
 (0)