Skip to content

Commit 57ce0ab

Browse files
committed
define zeros and randn
1 parent 0244131 commit 57ce0ab

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FusionTensors"
22
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.1"
4+
version = "0.4.2"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -11,6 +11,7 @@ GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
1111
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
1212
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1415
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
1516
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1617
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
@@ -25,6 +26,7 @@ GradedArrays = "0.4.13"
2526
HalfIntegers = "1.6"
2627
LRUCache = "1.6"
2728
LinearAlgebra = "1.10"
29+
Random = "1.11.0"
2830
Strided = "2.3"
2931
TensorAlgebra = "0.3.8"
3032
TensorProducts = "0.1.7"

src/fusiontensor/base_interface.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# This files defines Base functions for FusionTensor
22

33
using Accessors: @set
4-
5-
using BlockSparseArrays: @view!
4+
using Random: randn!
5+
using BlockSparseArrays: @view!, eachstoredblock
66
using TensorAlgebra: BlockedTuple, tuplemortar
77

88
set_data_matrix(ft::FusionTensor, data_matrix) = @set ft.data_matrix = data_matrix
@@ -81,6 +81,15 @@ end
8181

8282
Base.permutedims(ft::FusionTensor, args...) = fusiontensor_permutedims(ft, args...)
8383

84+
function Base.randn(::Type{T}, fta::FusionTensorAxes) where {T}
85+
ft = FusionTensor(T, fta)
86+
for m in eachstoredblock(data_matrix(ft))
87+
m = randn!(m)
88+
end
89+
return ft
90+
end
91+
Base.randn(fta::FusionTensorAxes) = randn(Float64, fta)
92+
8493
function Base.similar(ft::FusionTensor, T::Type)
8594
# reuse trees_block_mapping
8695

@@ -124,3 +133,6 @@ function Base.view(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
124133
charge_matrix = @view! data_matrix(ft)[trees_block_mapping(ft)[f1, f2]]
125134
return reshape(charge_matrix, charge_block_size(ft, f1, f2))
126135
end
136+
137+
Base.zeros(::Type{T}, fta::FusionTensorAxes) where {T} = FusionTensor(T, fta)
138+
Base.zeros(fta::FusionTensorAxes) = zeros(Float64, fta)

test/test_basics.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ include("setup.jl")
3535
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
3636
g2 = dual(gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1]))
3737

38+
fta = FusionTensorAxes((g1,), (g2,))
39+
ft0 = FusionTensor(Float64, fta)
40+
@test ft0 isa FusionTensor
41+
@test space_isequal(codomain_axis(ft0), g1)
42+
@test space_isequal(domain_axis(ft0), g2)
43+
3844
# check dual convention when initializing data_matrix
3945
ft0 = FusionTensor(Float64, (g1,), (g2,))
4046
@test ft0 isa FusionTensor
@@ -264,6 +270,13 @@ end
264270
@test_throws DimensionMismatch ft7 + ft3
265271
@test_throws DimensionMismatch ft7 - ft3
266272
@test_throws DimensionMismatch ft7 * ft3
273+
274+
fta = FusionTensorAxes((g1,), (g2, g2))
275+
@test zeros(fta) isa FusionTensor{Float64,3}
276+
@test zeros(ComplexF64, fta) isa FusionTensor{ComplexF64,3}
277+
ft9 = randn(ComplexF64, fta)
278+
@test ft9 isa FusionTensor{ComplexF64,3}
279+
@test all(data_matrix(ft9)[Block(1, 6)] .!= 0)
267280
end
268281

269282
@testset "missing SectorProduct" begin

0 commit comments

Comments
 (0)