@@ -143,11 +143,11 @@ end
143143
144144# Block and fusion tree ranges: structure information for building tensors
145145# --------------------------------------------------------------------------
146- struct FusionBlockStructure{I,F₁,F₂}
146+ struct FusionBlockStructure{I,N, F₁,F₂}
147147 totaldim:: Int
148148 blockstructure:: SectorDict{I,Tuple{Tuple{Int,Int},UnitRange{Int}}}
149149 fusiontreelist:: Vector{Tuple{F₁,F₂}}
150- fusiontreestructure:: Vector{Tuple{Tuple{Int ,Int},Tuple{Int ,Int},Int}}
150+ fusiontreestructure:: Vector{Tuple{NTuple{N ,Int},NTuple{N ,Int},Int}}
151151 fusiontreeindices:: FusionTreeDict{Tuple{F₁,F₂},Int}
152152end
153153
@@ -174,7 +174,7 @@ function fusionblockstructure(W::HomSpace, ::NoCache)
174174 # output structure
175175 blockstructure = SectorDict {I,Tuple{Tuple{Int,Int},UnitRange{Int}}} () # size, range
176176 fusiontreelist = Vector {Tuple{F₁,F₂}} ()
177- fusiontreestructure = Vector {Tuple{Tuple{Int ,Int},Tuple{Int ,Int},Int}} () # size, strides, offset
177+ fusiontreestructure = Vector {Tuple{NTuple{N₁ + N₂ ,Int},NTuple{N₁ + N₂ ,Int},Int}} () # size, strides, offset
178178
179179 # temporary data structures
180180 splittingtrees = Vector {F₁} ()
@@ -203,7 +203,10 @@ function fusionblockstructure(W::HomSpace, ::NoCache)
203203 for (f₁, (offset₁, d₁)) in zip (splittingtrees, splittingstructure)
204204 push! (fusiontreelist, (f₁, f₂))
205205 totaloffset = blockoffset + offset₂ * blockdim₁ + offset₁
206- push! (fusiontreestructure, ((d₁, d₂), strides, totaloffset))
206+ subsz = (dims (codom, f₁. uncoupled)... , dims (dom, f₂. uncoupled)... )
207+ @assert ! any (isequal (0 ), subsz)
208+ substr = _subblock_strides (subsz, (d₁, d₂), strides)
209+ push! (fusiontreestructure, (subsz, substr, totaloffset))
207210 end
208211 offset₂ += d₂
209212 end
@@ -221,22 +224,28 @@ function fusionblockstructure(W::HomSpace, ::NoCache)
221224 fusiontreeindices[f₁₂] = i
222225 end
223226 totaldim = blockoffset
224- structure = FusionBlockStructure {I,F₁,F₂} (totaldim, blockstructure,
225- fusiontreelist, fusiontreestructure,
226- fusiontreeindices)
227+ structure = FusionBlockStructure (totaldim, blockstructure,
228+ fusiontreelist, fusiontreestructure,
229+ fusiontreeindices)
227230 return structure
228231end
229232
233+ function _subblock_strides (subsz, sz, str)
234+ sz_simplify = Strided. StridedViews. _simplifydims (sz, str)
235+ return Strided. StridedViews. _computereshapestrides (subsz, sz_simplify... )
236+ end
237+
230238function fusionblockstructure (W:: HomSpace , :: TaskLocalCache{D} ) where {D}
231239 cache:: D = get! (task_local_storage (), :_local_tensorstructure_cache ) do
232240 return D ()
233241 end
234242 N₁ = length (codomain (W))
235243 N₂ = length (domain (W))
244+ N = N₁ + N₂
236245 I = sectortype (W)
237246 F₁ = fusiontreetype (I, N₁)
238247 F₂ = fusiontreetype (I, N₂)
239- structure:: FusionBlockStructure{I,F₁,F₂} = get! (cache, W) do
248+ structure:: FusionBlockStructure{I,N, F₁,F₂} = get! (cache, W) do
240249 return fusionblockstructure (W, NoCache ())
241250 end
242251 return structure
@@ -248,10 +257,11 @@ function fusionblockstructure(W::HomSpace, ::GlobalLRUCache)
248257 cache = GLOBAL_FUSIONBLOCKSTRUCTURE_CACHE
249258 N₁ = length (codomain (W))
250259 N₂ = length (domain (W))
260+ N = N₁ + N₂
251261 I = sectortype (W)
252262 F₁ = fusiontreetype (I, N₁)
253263 F₂ = fusiontreetype (I, N₂)
254- structure:: FusionBlockStructure{I,F₁,F₂} = get! (cache, W) do
264+ structure:: FusionBlockStructure{I,N, F₁,F₂} = get! (cache, W) do
255265 return fusionblockstructure (W, NoCache ())
256266 end
257267 return structure
0 commit comments