Skip to content

Commit 9e38f28

Browse files
lkdvosJutho
authored andcommitted
Update fusiontreeblockstructure to have reshaped strides
1 parent b3aabea commit 9e38f28

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

src/spaces/homspace.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,11 @@ end
143143

144144
# Block and fusion tree ranges: structure information for building tensors
145145
#--------------------------------------------------------------------------
146-
struct FusionBlockStructure{I,F₁,F₂}
146+
struct FusionBlockStructure{I,N,F₁,F₂}
147147
totaldim::Int
148148
blockstructure::SectorDict{I,Tuple{Tuple{Int,Int},UnitRange{Int}}}
149149
fusiontreelist::Vector{Tuple{F₁,F₂}}
150-
fusiontreestructure::Vector{Tuple{Tuple{Int,Int},Tuple{Int,Int},Int}}
150+
fusiontreestructure::Vector{Tuple{NTuple{N,Int},NTuple{N,Int},Int}}
151151
fusiontreeindices::FusionTreeDict{Tuple{F₁,F₂},Int}
152152
end
153153

@@ -174,7 +174,7 @@ function fusionblockstructure(W::HomSpace, ::NoCache)
174174
# output structure
175175
blockstructure = SectorDict{I,Tuple{Tuple{Int,Int},UnitRange{Int}}}() # size, range
176176
fusiontreelist = Vector{Tuple{F₁,F₂}}()
177-
fusiontreestructure = Vector{Tuple{Tuple{Int,Int},Tuple{Int,Int},Int}}() # size, strides, offset
177+
fusiontreestructure = Vector{Tuple{NTuple{N₁ + N₂,Int},NTuple{N₁ + N₂,Int},Int}}() # size, strides, offset
178178

179179
# temporary data structures
180180
splittingtrees = Vector{F₁}()
@@ -203,7 +203,10 @@ function fusionblockstructure(W::HomSpace, ::NoCache)
203203
for (f₁, (offset₁, d₁)) in zip(splittingtrees, splittingstructure)
204204
push!(fusiontreelist, (f₁, f₂))
205205
totaloffset = blockoffset + offset₂ * blockdim₁ + offset₁
206-
push!(fusiontreestructure, ((d₁, d₂), strides, totaloffset))
206+
subsz = (dims(codom, f₁.uncoupled)..., dims(dom, f₂.uncoupled)...)
207+
@assert !any(isequal(0), subsz)
208+
substr = _subblock_strides(subsz, (d₁, d₂), strides)
209+
push!(fusiontreestructure, (subsz, substr, totaloffset))
207210
end
208211
offset₂ += d₂
209212
end
@@ -221,22 +224,28 @@ function fusionblockstructure(W::HomSpace, ::NoCache)
221224
fusiontreeindices[f₁₂] = i
222225
end
223226
totaldim = blockoffset
224-
structure = FusionBlockStructure{I,F₁,F₂}(totaldim, blockstructure,
225-
fusiontreelist, fusiontreestructure,
226-
fusiontreeindices)
227+
structure = FusionBlockStructure(totaldim, blockstructure,
228+
fusiontreelist, fusiontreestructure,
229+
fusiontreeindices)
227230
return structure
228231
end
229232

233+
function _subblock_strides(subsz, sz, str)
234+
sz_simplify = Strided.StridedViews._simplifydims(sz, str)
235+
return Strided.StridedViews._computereshapestrides(subsz, sz_simplify...)
236+
end
237+
230238
function fusionblockstructure(W::HomSpace, ::TaskLocalCache{D}) where {D}
231239
cache::D = get!(task_local_storage(), :_local_tensorstructure_cache) do
232240
return D()
233241
end
234242
N₁ = length(codomain(W))
235243
N₂ = length(domain(W))
244+
N = N₁ + N₂
236245
I = sectortype(W)
237246
F₁ = fusiontreetype(I, N₁)
238247
F₂ = fusiontreetype(I, N₂)
239-
structure::FusionBlockStructure{I,F₁,F₂} = get!(cache, W) do
248+
structure::FusionBlockStructure{I,N,F₁,F₂} = get!(cache, W) do
240249
return fusionblockstructure(W, NoCache())
241250
end
242251
return structure
@@ -248,10 +257,11 @@ function fusionblockstructure(W::HomSpace, ::GlobalLRUCache)
248257
cache = GLOBAL_FUSIONBLOCKSTRUCTURE_CACHE
249258
N₁ = length(codomain(W))
250259
N₂ = length(domain(W))
260+
N = N₁ + N₂
251261
I = sectortype(W)
252262
F₁ = fusiontreetype(I, N₁)
253263
F₂ = fusiontreetype(I, N₂)
254-
structure::FusionBlockStructure{I,F₁,F₂} = get!(cache, W) do
264+
structure::FusionBlockStructure{I,N,F₁,F₂} = get!(cache, W) do
255265
return fusionblockstructure(W, NoCache())
256266
end
257267
return structure

src/tensors/tensor.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,7 @@ column indices correspond to `f₂.uncoupled`.
453453
@inbounds begin
454454
i = structure.fusiontreeindices[(f₁, f₂)]
455455
sz, str, offset = structure.fusiontreestructure[i]
456-
subblock = StridedView(t.data, sz, str, offset)
457-
d = (dims(codomain(t), f₁.uncoupled)..., dims(domain(t), f₂.uncoupled)...)
458-
return sreshape(subblock, d)
456+
return StridedView(t.data, sz, str, offset)
459457
end
460458
end
461459
# The following is probably worth special casing for trivial tensors

0 commit comments

Comments
 (0)