Skip to content

Commit e3f68ad

Browse files
committed
Refactor treetransformer to make use of vectorized implementation
1 parent 9ccc644 commit e3f68ad

File tree

1 file changed

+18
-31
lines changed

1 file changed

+18
-31
lines changed

src/tensors/treetransformers.jl

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -62,39 +62,26 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
6262
fusionstructure_dst = structure_dst.fusiontreestructure
6363
structure_src = fusionblockstructure(Vsrc)
6464
fusionstructure_src = structure_src.fusiontreestructure
65-
I = sectortype(Vsrc)
66-
67-
uncoupleds_src = map(structure_src.fusiontreelist) do (f₁, f₂)
68-
return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled))
69-
end
70-
uncoupleds_src_unique = unique(uncoupleds_src)
71-
72-
uncoupleds_dst = map(structure_dst.fusiontreelist) do (f₁, f₂)
73-
return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled))
74-
end
7565

66+
I = sectortype(Vsrc)
7667
T = sectorscalartype(I)
7768
N = numind(Vdst)
78-
L = length(uncoupleds_src_unique)
79-
data = Vector{_GenericTransformerData{T,N}}(undef, L)
69+
data = Vector{_GenericTransformerData{T,N}}()
8070

81-
# TODO: this can be multithreaded
82-
for (i, uncoupled) in enumerate(uncoupleds_src_unique)
83-
inds_src = findall(==(uncoupled), uncoupleds_src)
84-
fusiontrees_outer_src = structure_src.fusiontreelist[inds_src]
71+
isdual_src = (map(isdual, codomain(Vsrc).spaces), map(isdual, domain(Vsrc).spaces))
72+
for cod_uncoupled_src in sectors(codomain(Vsrc)),
73+
dom_uncoupled_src in sectors(domain(Vsrc))
8574

86-
uncoupled_dst = TupleTools.getindices(uncoupled, (p[1]..., p[2]...))
87-
inds_dst = findall(==(uncoupled_dst), uncoupleds_dst)
75+
fs_src = OuterTreeIterator((cod_uncoupled_src, dom_uncoupled_src), isdual_src)
76+
trees_src = fusiontrees(fs_src)
77+
isempty(trees_src) && continue
8878

89-
fusiontrees_outer_dst = structure_dst.fusiontreelist[inds_dst]
79+
fs_dst, U = transform(fs_src)
80+
matrix = copy(transpose(U)) # TODO: should we avoid this
9081

91-
matrix = zeros(sectorscalartype(I), length(inds_dst), length(inds_src))
92-
for (row, (f₁, f₂)) in enumerate(fusiontrees_outer_src)
93-
for ((f₃, f₄), coeff) in transform(f₁, f₂)
94-
col = findfirst(==((f₃, f₄)), fusiontrees_outer_dst)::Int
95-
matrix[row, col] = coeff
96-
end
97-
end
82+
inds_src = map(Base.Fix1(getindex, structure_src.fusiontreeindices), trees_src)
83+
trees_dst = fusiontrees(fs_dst)
84+
inds_dst = map(Base.Fix1(getindex, structure_dst.fusiontreeindices), trees_dst)
9885

9986
# size is shared between blocks, so repack:
10087
# from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
@@ -104,7 +91,7 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
10491
@debug("Created recoupling block for uncoupled: $uncoupled",
10592
sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix))
10693

107-
data[i] = (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src))
94+
push!(data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src)))
10895
end
10996

11097
transformer = GenericTreeTransformer{T,N}(data)
@@ -166,29 +153,29 @@ end
166153

167154
# braid is special because it has levels
168155
function treebraider(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple, levels)
169-
return fusiontreetransform((f1, f2)) = braid((f1, f2), levels, p)
156+
return fusiontreetransform(f) = braid(f, levels, p)
170157
end
171158
function treebraider(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple, levels)
172159
return treebraider(space(tdst), space(tsrc), p, levels)
173160
end
174161
@cached function treebraider(Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple,
175162
levels)::treetransformertype(Vdst, Vsrc)
176-
fusiontreebraider((f1, f2)) = braid((f1, f2), levels, p)
163+
fusiontreebraider(f) = braid(f, levels, p)
177164
return TreeTransformer(fusiontreebraider, p, Vdst, Vsrc)
178165
end
179166

180167
for (transform, treetransformer) in
181168
((:permute, :treepermuter), (:transpose, :treetransposer))
182169
@eval begin
183170
function $treetransformer(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple)
184-
return fusiontreetransform(f1, f2) = $transform((f1, f2), p)
171+
return fusiontreetransform(f) = $transform(f, p)
185172
end
186173
function $treetransformer(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple)
187174
return $treetransformer(space(tdst), space(tsrc), p)
188175
end
189176
@cached function $treetransformer(Vdst::TensorMapSpace, Vsrc::TensorMapSpace,
190177
p::Index2Tuple)::treetransformertype(Vdst, Vsrc)
191-
fusiontreetransform((f1, f2)) = $transform((f1, f2), p)
178+
fusiontreetransform(f) = $transform(f, p)
192179
return TreeTransformer(fusiontreetransform, p, Vdst, Vsrc)
193180
end
194181
end

0 commit comments

Comments
 (0)