@@ -62,39 +62,26 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
6262 fusionstructure_dst = structure_dst. fusiontreestructure
6363 structure_src = fusionblockstructure (Vsrc)
6464 fusionstructure_src = structure_src. fusiontreestructure
65- I = sectortype (Vsrc)
66-
67- uncoupleds_src = map (structure_src. fusiontreelist) do (f₁, f₂)
68- return TupleTools. vcat (f₁. uncoupled, dual .(f₂. uncoupled))
69- end
70- uncoupleds_src_unique = unique (uncoupleds_src)
71-
72- uncoupleds_dst = map (structure_dst. fusiontreelist) do (f₁, f₂)
73- return TupleTools. vcat (f₁. uncoupled, dual .(f₂. uncoupled))
74- end
7565
66+ I = sectortype (Vsrc)
7667 T = sectorscalartype (I)
7768 N = numind (Vdst)
78- L = length (uncoupleds_src_unique)
79- data = Vector {_GenericTransformerData{T,N}} (undef, L)
69+ data = Vector {_GenericTransformerData{T,N}} ()
8070
81- # TODO : this can be multithreaded
82- for (i, uncoupled) in enumerate (uncoupleds_src_unique)
83- inds_src = findall (== (uncoupled), uncoupleds_src)
84- fusiontrees_outer_src = structure_src. fusiontreelist[inds_src]
71+ isdual_src = (map (isdual, codomain (Vsrc). spaces), map (isdual, domain (Vsrc). spaces))
72+ for cod_uncoupled_src in sectors (codomain (Vsrc)),
73+ dom_uncoupled_src in sectors (domain (Vsrc))
8574
86- uncoupled_dst = TupleTools. getindices (uncoupled, (p[1 ]. .. , p[2 ]. .. ))
87- inds_dst = findall (== (uncoupled_dst), uncoupleds_dst)
75+ fs_src = OuterTreeIterator ((cod_uncoupled_src, dom_uncoupled_src), isdual_src)
76+ trees_src = fusiontrees (fs_src)
77+ isempty (trees_src) && continue
8878
89- fusiontrees_outer_dst = structure_dst. fusiontreelist[inds_dst]
79+ fs_dst, U = transform (fs_src)
80+ matrix = copy (transpose (U)) # TODO : should we avoid this
9081
91- matrix = zeros (sectorscalartype (I), length (inds_dst), length (inds_src))
92- for (row, (f₁, f₂)) in enumerate (fusiontrees_outer_src)
93- for ((f₃, f₄), coeff) in transform (f₁, f₂)
94- col = findfirst (== ((f₃, f₄)), fusiontrees_outer_dst):: Int
95- matrix[row, col] = coeff
96- end
97- end
82+ inds_src = map (Base. Fix1 (getindex, structure_src. fusiontreeindices), trees_src)
83+ trees_dst = fusiontrees (fs_dst)
84+ inds_dst = map (Base. Fix1 (getindex, structure_dst. fusiontreeindices), trees_dst)
9885
9986 # size is shared between blocks, so repack:
10087 # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
@@ -104,7 +91,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
10491 @debug (" Created recoupling block for uncoupled: $uncoupled " ,
10592 sz = size (matrix), sparsity = count (! iszero, matrix) / length (matrix))
10693
107- data[i] = (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src))
94+ push! ( data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src) ))
10895 end
10996
11097 transformer = GenericTreeTransformer {T,N} (data)
@@ -166,29 +153,29 @@ end
166153
167154# braid is special because it has levels
168155function treebraider (:: AbstractTensorMap , :: AbstractTensorMap , p:: Index2Tuple , levels)
169- return fusiontreetransform ((f1, f2)) = braid ((f1, f2) , levels, p)
156+ return fusiontreetransform (f) = braid (f , levels, p)
170157end
171158function treebraider (tdst:: TensorMap , tsrc:: TensorMap , p:: Index2Tuple , levels)
172159 return treebraider (space (tdst), space (tsrc), p, levels)
173160end
174161@cached function treebraider (Vdst:: TensorMapSpace , Vsrc:: TensorMapSpace , p:: Index2Tuple ,
175162 levels):: treetransformertype (Vdst, Vsrc)
176- fusiontreebraider ((f1, f2)) = braid ((f1, f2) , levels, p)
163+ fusiontreebraider (f) = braid (f , levels, p)
177164 return TreeTransformer (fusiontreebraider, p, Vdst, Vsrc)
178165end
179166
180167for (transform, treetransformer) in
181168 ((:permute , :treepermuter ), (:transpose , :treetransposer ))
182169 @eval begin
183170 function $treetransformer (:: AbstractTensorMap , :: AbstractTensorMap , p:: Index2Tuple )
184- return fusiontreetransform (f1, f2 ) = $ transform ((f1, f2) , p)
171+ return fusiontreetransform (f ) = $ transform (f , p)
185172 end
186173 function $treetransformer (tdst:: TensorMap , tsrc:: TensorMap , p:: Index2Tuple )
187174 return $ treetransformer (space (tdst), space (tsrc), p)
188175 end
189176 @cached function $treetransformer (Vdst:: TensorMapSpace , Vsrc:: TensorMapSpace ,
190177 p:: Index2Tuple ):: treetransformertype (Vdst, Vsrc)
191- fusiontreetransform ((f1, f2)) = $ transform ((f1, f2) , p)
178+ fusiontreetransform (f) = $ transform (f , p)
192179 return TreeTransformer (fusiontreetransform, p, Vdst, Vsrc)
193180 end
194181 end
0 commit comments