Skip to content

Commit 9247e38

Browse files
committed
type stability improvements
1 parent fe3cfd4 commit 9247e38

File tree

4 files changed

+22
-22
lines changed

4 files changed

+22
-22
lines changed

src/fusiontrees/fusiontreeblocks.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,10 @@ end
607607

608608
const _FSBraidKey{I,N₁,N₂} = Tuple{<:FusionTreeBlock{I},Index2Tuple{N₁,N₂},Index2Tuple}
609609

610-
@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{FusionTreeBlock{I,N₁,N₂},
610+
@cached function _fsbraid(key::_FSBraidKey{I,N₁,N₂})::Tuple{FusionTreeBlock{I,N₁,N₂,
611+
fusiontreetype(I,
612+
N₁,
613+
N₂)},
611614
Matrix{sectorscalartype(I)}} where {I,
612615
N₁,
613616
N₂}

src/fusiontrees/fusiontrees.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ end
136136
Base.:(==)(f₁::FusionTree, f₂::FusionTree) = false
137137

138138
# Facilitate getting correct fusion tree types
139-
function fusiontreetype(::Type{I}, N::Int) where {I<:Sector}
139+
Base.@assume_effects :foldable function fusiontreetype(::Type{I}, N::Int) where {I<:Sector}
140140
if N === 0
141141
FusionTree{I,0,0,0}
142142
elseif N === 1
@@ -145,7 +145,7 @@ function fusiontreetype(::Type{I}, N::Int) where {I<:Sector}
145145
FusionTree{I,N,N - 2,N - 1}
146146
end
147147
end
148-
function fusiontreetype(::Type{I}, N₁::Int, N₂::Int) where {I<:Sector}
148+
Base.@assume_effects :foldable function fusiontreetype(::Type{I}, N₁::Int, N₂::Int) where {I<:Sector}
149149
return Tuple{fusiontreetype(I, N₁),fusiontreetype(I, N₂)}
150150
end
151151

src/fusiontrees/manipulations.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,8 @@ function artin_braid(f::FusionTree{I,N}, i; inv::Bool=false) where {I,N}
884884
return fusiontreedict(I)(f′ => coeff)
885885
elseif FusionStyle(I) isa SimpleFusion
886886
local newtrees
887-
for c′ in intersect(a d, e conj(b))
887+
cs = collect(I, intersect(a d, e conj(b)))
888+
for c′ in cs
888889
coeff = oftype(oneT,
889890
if inv
890891
conj(Rsymbol(d, c, e) * Fsymbol(d, a, b, e, c′, c)) *
@@ -905,7 +906,8 @@ function artin_braid(f::FusionTree{I,N}, i; inv::Bool=false) where {I,N}
905906
return newtrees
906907
else # GenericFusion
907908
local newtrees
908-
for c′ in intersect(a d, e conj(b))
909+
cs = collect(I, intersect(a d, e conj(b)))
910+
for c′ in cs
909911
Rmat1 = inv ? Rsymbol(d, c, e)' : Rsymbol(c, d, e)
910912
Rmat2 = inv ? Rsymbol(d, a, c′)' : Rsymbol(a, d, c′)
911913
Fmat = Fsymbol(d, a, b, e, c′, c)

src/tensors/treetransformers.jl

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,13 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
6666
I = sectortype(Vsrc)
6767
T = sectorscalartype(I)
6868
N = numind(Vdst)
69+
N₁ = numout(Vsrc)
70+
N₂ = numin(Vsrc)
6971

7072
isdual_src = (map(isdual, codomain(Vsrc).spaces), map(isdual, domain(Vsrc).spaces))
7173

74+
data = Vector{_GenericTransformerData{T,N}}()
75+
7276
nthreads = get_num_transformer_threads()
7377
if nthreads > 1
7478
fusiontreeblocks = Vector{FusionTreeBlock{I,N₁,N₂,fusiontreetype(I, N₁, N₂)}}()
@@ -82,7 +86,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
8286
end
8387
end
8488

85-
data = Vector{_GenericTransformerData{T,N}}(undef, length(fusiontreeblocks))
89+
resize!(data, length(fusiontreeblocks))
8690
counter = Threads.Atomic{Int}(1)
8791
Threads.@sync for _ in 1:min(nthreads, length(fusiontreeblocks))
8892
Threads.@spawn begin
@@ -106,18 +110,13 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
106110
sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst,
107111
inds_dst)
108112

109-
@debug("Created recoupling block for uncoupled: $uncoupled",
110-
sz = size(matrix),
111-
sparsity = count(!iszero, matrix) / length(matrix))
112-
113-
data[local_counter] = (matrix, (sz_dst, newstructs_dst),
114-
(sz_src, newstructs_src))
113+
data1[local_counter] = (matrix, (sz_dst, newstructs_dst),
114+
(sz_src, newstructs_src))
115115
end
116116
end
117117
end
118+
transformer = GenericTreeTransformer{T,N}(data)
118119
else
119-
data = Vector{_GenericTransformerData{T,N}}()
120-
121120
isdual_src = (map(isdual, codomain(Vsrc).spaces), map(isdual, domain(Vsrc).spaces))
122121
for cod_uncoupled_src in sectors(codomain(Vsrc)),
123122
dom_uncoupled_src in sectors(domain(Vsrc))
@@ -140,24 +139,20 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
140139
sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst,
141140
inds_dst)
142141

143-
@debug("Created recoupling block for uncoupled: $uncoupled",
144-
sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix))
145-
146142
push!(data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src)))
147143
end
144+
transformer = GenericTreeTransformer{T,N}(data)
148145
end
149146

150-
transformer = GenericTreeTransformer{T,N}(data)
151-
152147
# sort by (approximate) weight to facilitate multi-threading strategies
153148
sort!(transformer)
154149

155150
Δt = Base.time() - t₀
156151

157152
@debug("TreeTransformer for $Vsrc to $Vdst via $p",
158-
nblocks = length(data),
159-
sz_median = size(data[cld(end, 2)][1], 1),
160-
sz_max = size(data[1][1], 1),
153+
nblocks = length(transformer.data),
154+
sz_median = size(transformer.data[cld(end, 2)][1], 1),
155+
sz_max = size(transformer.data[1][1], 1),
161156
Δt)
162157

163158
return transformer

0 commit comments

Comments
 (0)