@@ -62,39 +62,26 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
6262 fusionstructure_dst = structure_dst. fusiontreestructure
6363 structure_src = fusionblockstructure(Vsrc)
6464 fusionstructure_src = structure_src. fusiontreestructure
65- I = sectortype(Vsrc)
66-
67- uncoupleds_src = map(structure_src. fusiontreelist) do (f₁, f₂)
68- return TupleTools. vcat(f₁. uncoupled, dual.(f₂. uncoupled))
69- end
70- uncoupleds_src_unique = unique(uncoupleds_src)
71-
72- uncoupleds_dst = map(structure_dst. fusiontreelist) do (f₁, f₂)
73- return TupleTools. vcat(f₁. uncoupled, dual.(f₂. uncoupled))
74- end
7565
66+ I = sectortype(Vsrc)
7667 T = sectorscalartype(I)
7768 N = numind(Vdst)
78- L = length(uncoupleds_src_unique)
79- data = Vector{_GenericTransformerData{T,N}}(undef, L)
69+ data = Vector{_GenericTransformerData{T,N}}()
8070
81- # TODO : this can be multithreaded
82- for (i, uncoupled) in enumerate(uncoupleds_src_unique)
83- inds_src = findall(== (uncoupled), uncoupleds_src)
84- fusiontrees_outer_src = structure_src. fusiontreelist[inds_src]
71+ isdual_src = (map(isdual, codomain(Vsrc). spaces), map(isdual, domain(Vsrc). spaces))
72+ for cod_uncoupled_src in sectors(codomain(Vsrc)),
73+ dom_uncoupled_src in sectors(domain(Vsrc))
8574
86- uncoupled_dst = TupleTools. getindices(uncoupled, (p[1 ]. .. , p[2 ]. .. ))
87- inds_dst = findall(== (uncoupled_dst), uncoupleds_dst)
75+ fs_src = OuterTreeIterator((cod_uncoupled_src, dom_uncoupled_src), isdual_src)
76+ trees_src = fusiontrees(fs_src)
77+ isempty(trees_src) && continue
8878
89- fusiontrees_outer_dst = structure_dst. fusiontreelist[inds_dst]
79+ fs_dst, U = transform(fs_src)
80+ matrix = copy(transpose(U)) # TODO : should we avoid this
9081
91- matrix = zeros(sectorscalartype(I), length(inds_dst), length(inds_src))
92- for (row, (f₁, f₂)) in enumerate(fusiontrees_outer_src)
93- for ((f₃, f₄), coeff) in transform(f₁, f₂)
94- col = findfirst(== ((f₃, f₄)), fusiontrees_outer_dst):: Int
95- matrix[row, col] = coeff
96- end
97- end
82+ inds_src = map(Base. Fix1(getindex, structure_src. fusiontreeindices), trees_src)
83+ trees_dst = fusiontrees(fs_dst)
84+ inds_dst = map(Base. Fix1(getindex, structure_dst. fusiontreeindices), trees_dst)
9885
9986 # size is shared between blocks, so repack:
10087 # from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
@@ -104,7 +91,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
10491 @debug(" Created recoupling block for uncoupled: $uncoupled " ,
10592 sz = size(matrix), sparsity = count(! iszero, matrix) / length(matrix))
10693
107- data[i] = (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src))
94+ push!( data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src) ))
10895 end
10996
11097 transformer = GenericTreeTransformer{T,N}(data)
@@ -166,29 +153,29 @@ end
166153
167154# braid is special because it has levels
168155function treebraider(:: AbstractTensorMap , :: AbstractTensorMap , p:: Index2Tuple , levels)
169- return fusiontreetransform((f1, f2)) = braid((f1, f2) , levels, p)
156+ return fusiontreetransform(f) = braid(f , levels, p)
170157end
171158function treebraider(tdst:: TensorMap , tsrc:: TensorMap , p:: Index2Tuple , levels)
172159 return treebraider(space(tdst), space(tsrc), p, levels)
173160end
174161@cached function treebraider(Vdst:: TensorMapSpace , Vsrc:: TensorMapSpace , p:: Index2Tuple ,
175162 levels):: treetransformertype (Vdst, Vsrc)
176- fusiontreebraider((f1, f2)) = braid((f1, f2) , levels, p)
163+ fusiontreebraider(f) = braid(f , levels, p)
177164 return TreeTransformer(fusiontreebraider, p, Vdst, Vsrc)
178165end
179166
180167for (transform, treetransformer) in
181168 ((:permute, :treepermuter), (:transpose, :treetransposer))
182169 @eval begin
183170 function $ treetransformer(:: AbstractTensorMap , :: AbstractTensorMap , p:: Index2Tuple )
184- return fusiontreetransform(f1, f2 ) = $ transform((f1, f2) , p)
171+ return fusiontreetransform(f ) = $ transform(f , p)
185172 end
186173 function $ treetransformer(tdst:: TensorMap , tsrc:: TensorMap , p:: Index2Tuple )
187174 return $ treetransformer(space(tdst), space(tsrc), p)
188175 end
189176 @cached function $ treetransformer(Vdst:: TensorMapSpace , Vsrc:: TensorMapSpace ,
190177 p:: Index2Tuple ):: treetransformertype (Vdst, Vsrc)
191- fusiontreetransform((f1, f2)) = $ transform((f1, f2) , p)
178+ fusiontreetransform(f) = $ transform(f , p)
192179 return TreeTransformer(fusiontreetransform, p, Vdst, Vsrc)
193180 end
194181 end
0 commit comments