Skip to content

Commit fe461fd

Browse files
committed
define constructors
1 parent 57ce0ab commit fe461fd

File tree

3 files changed

+47
-17
lines changed

3 files changed

+47
-17
lines changed

src/fusiontensor/base_interface.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,6 @@ end
8181

8282
Base.permutedims(ft::FusionTensor, args...) = fusiontensor_permutedims(ft, args...)
8383

84-
function Base.randn(::Type{T}, fta::FusionTensorAxes) where {T}
85-
ft = FusionTensor(T, fta)
86-
for m in eachstoredblock(data_matrix(ft))
87-
m = randn!(m)
88-
end
89-
return ft
90-
end
91-
Base.randn(fta::FusionTensorAxes) = randn(Float64, fta)
92-
9384
function Base.similar(ft::FusionTensor, T::Type)
9485
# reuse trees_block_mapping
9586

@@ -133,6 +124,3 @@ function Base.view(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
133124
charge_matrix = @view! data_matrix(ft)[trees_block_mapping(ft)[f1, f2]]
134125
return reshape(charge_matrix, charge_block_size(ft, f1, f2))
135126
end
136-
137-
Base.zeros(::Type{T}, fta::FusionTensorAxes) where {T} = FusionTensor(T, fta)
138-
Base.zeros(fta::FusionTensorAxes) = zeros(Float64, fta)

src/fusiontensor/fusiontensor.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ using GradedArrays:
1919
sectormergesort,
2020
sectors,
2121
space_isequal
22+
using LinearAlgebra: UniformScaling
2223
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
2324
using TensorProducts: tensor_product
2425
using TypeParameterAccessors: type_parameters
@@ -209,6 +210,30 @@ function FusionTensor(
209210
return FusionTensor(x, tuplemortar((codomain_legs, domain_legs)))
210211
end
211212

213+
# specific constructors
214+
Base.zeros(::Type{T}, fta::FusionTensorAxes) where {T} = FusionTensor(T, fta)
215+
Base.zeros(fta::FusionTensorAxes) = zeros(Float64, fta)
216+
217+
function Base.randn(::Type{T}, fta::FusionTensorAxes) where {T}
218+
ft = FusionTensor(T, fta)
219+
for m in eachstoredblock(data_matrix(ft))
220+
m = randn!(m)
221+
end
222+
return ft
223+
end
224+
Base.randn(fta::FusionTensorAxes) = randn(Float64, fta)
225+
226+
function FusionTensor(
227+
::UniformScaling, codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}}
228+
)
229+
fta = FusionTensorAxes(codomain_legs, dual.(codomain_legs))
230+
ft = FusionTensor(Float64, fta)
231+
for m in eachstoredblock(data_matrix(ft))
232+
m .= LinearAlgebra.I(size(m, 1))
233+
end
234+
return ft
235+
end
236+
212237
# ================================ BlockArrays interface =================================
213238

214239
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)

test/test_basics.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Test: @test, @test_throws, @testset
22

33
using BlockArrays: Block
4-
using BlockSparseArrays: BlockSparseArray
4+
using BlockSparseArrays: BlockSparseArray, eachblockstoredindex
55
using FusionTensors:
66
FusionTensor,
77
FusionTensorAxes,
@@ -28,6 +28,7 @@ using GradedArrays:
2828
space_isequal
2929
using TensorAlgebra: tuplemortar
3030
using TensorProducts: tensor_product
31+
using LinearAlgebra: LinearAlgebra
3132

3233
include("setup.jl")
3334

@@ -270,13 +271,29 @@ end
270271
@test_throws DimensionMismatch ft7 + ft3
271272
@test_throws DimensionMismatch ft7 - ft3
272273
@test_throws DimensionMismatch ft7 * ft3
274+
end
275+
276+
@testset "specific constructors" begin
277+
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
278+
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
279+
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
280+
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])
273281

274-
fta = FusionTensorAxes((g1,), (g2, g2))
282+
fta = FusionTensorAxes((g1,), (g2, g3))
275283
@test zeros(fta) isa FusionTensor{Float64,3}
276284
@test zeros(ComplexF64, fta) isa FusionTensor{ComplexF64,3}
277-
ft9 = randn(ComplexF64, fta)
278-
@test ft9 isa FusionTensor{ComplexF64,3}
279-
@test all(data_matrix(ft9)[Block(1, 6)] .!= 0)
285+
ft1 = randn(ComplexF64, fta)
286+
@test ft1 isa FusionTensor{ComplexF64,3}
287+
@test all(data_matrix(ft1)[Block(1, 5)] .!= 0)
288+
289+
ft2 = FusionTensor(LinearAlgebra.I, (g1, g2))
290+
@test ft2 isa FusionTensor{Float64,4}
291+
@test axes(ft2) == FusionTensorAxes((g1, g2), dual.((g1, g2)))
292+
@test collect(eachblockstoredindex(data_matrix(ft2))) == map(i -> Block(i, i), 1:6)
293+
for i in 1:6
294+
m = data_matrix(ft2)[Block(i, i)]
295+
@test m == LinearAlgebra.I(size(m, 1))
296+
end
280297
end
281298

282299
@testset "missing SectorProduct" begin

0 commit comments

Comments
 (0)