Skip to content

Commit f2eee46

Browse files
committed
define randn(::RNG)
1 parent c9d14c4 commit f2eee46

File tree

4 files changed

+32
-9
lines changed

4 files changed

+32
-9
lines changed

src/fusiontensor/base_interface.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# This files defines Base functions for FusionTensor
22

33
using Accessors: @set
4-
using Random: randn!
54
using BlockSparseArrays: @view!, eachstoredblock
65
using TensorAlgebra: BlockedTuple, tuplemortar
76

src/fusiontensor/fusiontensor.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ using GradedArrays:
2020
sectors,
2121
space_isequal
2222
using LinearAlgebra: UniformScaling
23+
using Random: Random, AbstractRNG, randn!
2324
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
2425
using TensorProducts: tensor_product
2526
using TypeParameterAccessors: type_parameters
@@ -214,25 +215,32 @@ end
214215
Base.zeros(::Type{T}, fta::FusionTensorAxes) where {T} = FusionTensor(T, fta)
215216
Base.zeros(fta::FusionTensorAxes) = zeros(Float64, fta)
216217

217-
function Base.randn(::Type{T}, fta::FusionTensorAxes) where {T}
218+
function Base.randn(rng::AbstractRNG, ::Type{T}, fta::FusionTensorAxes) where {T}
218219
ft = FusionTensor(T, fta)
219220
for m in eachstoredblock(data_matrix(ft))
220-
randn!(m)
221+
randn!(rng, m)
221222
end
222223
return ft
223224
end
225+
Base.randn(rng::AbstractRNG, fta::FusionTensorAxes) = randn(rng, Float64, fta)
226+
Base.randn(::Type{T}, fta::FusionTensorAxes) where {T} = randn(Random.default_rng(), T, fta)
224227
Base.randn(fta::FusionTensorAxes) = randn(Float64, fta)
225228

226-
function FusionTensor(
227-
::UniformScaling, codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}}
228-
)
229+
function FusionTensor{T}(
230+
s::UniformScaling, codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}}
231+
) where {T}
229232
fta = FusionTensorAxes(codomain_legs, dual.(codomain_legs))
230-
ft = FusionTensor(Float64, fta)
233+
ft = FusionTensor(T, fta)
231234
for m in eachstoredblock(data_matrix(ft))
232-
m .= LinearAlgebra.I(size(m, 1))
235+
m .= s(size(m, 1))
233236
end
234237
return ft
235238
end
239+
function FusionTensor(
240+
s::UniformScaling, codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}}
241+
)
242+
return FusionTensor{Float64}(s, codomain_legs)
243+
end
236244

237245
# ================================ BlockArrays interface =================================
238246

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
55
FusionTensors = "e16ca583-1f51-4df0-8e12-57d32947d33e"
66
GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
910
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1011
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
@@ -18,6 +19,7 @@ BlockSparseArrays = "0.7"
1819
FusionTensors = "0.4"
1920
GradedArrays = "0.4"
2021
LinearAlgebra = "1.10.0"
22+
Random = "1.10"
2123
SafeTestsets = "0.1.0"
2224
Suppressor = "0.2.8"
2325
TensorAlgebra = "0.3"

test/test_basics.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ using GradedArrays:
2929
using TensorAlgebra: tuplemortar
3030
using TensorProducts: tensor_product
3131
using LinearAlgebra: LinearAlgebra
32+
using Random: Random
3233

3334
include("setup.jl")
3435

@@ -282,9 +283,13 @@ end
282283
fta = FusionTensorAxes((g1,), (g2, g3))
283284
@test zeros(fta) isa FusionTensor{Float64,3}
284285
@test zeros(ComplexF64, fta) isa FusionTensor{ComplexF64,3}
285-
ft1 = randn(ComplexF64, fta)
286+
287+
rng = Random.default_rng()
288+
ft1 = randn(rng, ComplexF64, fta)
286289
@test ft1 isa FusionTensor{ComplexF64,3}
287290
@test all(!=(0), data_matrix(ft1)[Block(1, 5)])
291+
@test randn(rng, fta) isa FusionTensor{Float64,3}
292+
@test randn(ComplexF64, fta) isa FusionTensor{ComplexF64,3}
288293
@test randn(fta) isa FusionTensor{Float64,3}
289294

290295
ft2 = FusionTensor(LinearAlgebra.I, (g1, g2))
@@ -295,6 +300,15 @@ end
295300
m = data_matrix(ft2)[Block(i, i)]
296301
@test m == LinearAlgebra.I(size(m, 1))
297302
end
303+
304+
ft2 = FusionTensor(3 * LinearAlgebra.I, (g1, g2))
305+
@test ft2 isa FusionTensor{Float64,4}
306+
@test axes(ft2) == FusionTensorAxes((g1, g2), dual.((g1, g2)))
307+
@test collect(eachblockstoredindex(data_matrix(ft2))) == map(i -> Block(i, i), 1:6)
308+
for i in 1:6
309+
m = data_matrix(ft2)[Block(i, i)]
310+
@test m == 3 * LinearAlgebra.I(size(m, 1))
311+
end
298312
end
299313

300314
@testset "missing SectorProduct" begin

0 commit comments

Comments
 (0)