Skip to content

Commit fff2d5a

Browse files
committed
add multithreaded treetransformer implementation
1 parent 731a832 commit fff2d5a

File tree

1 file changed

+71
-18
lines changed

1 file changed

+71
-18
lines changed

src/tensors/treetransformers.jl

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,32 +66,85 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
6666
I = sectortype(Vsrc)
6767
T = sectorscalartype(I)
6868
N = numind(Vdst)
69-
data = Vector{_GenericTransformerData{T,N}}()
7069

7170
isdual_src = (map(isdual, codomain(Vsrc).spaces), map(isdual, domain(Vsrc).spaces))
72-
for cod_uncoupled_src in sectors(codomain(Vsrc)),
73-
dom_uncoupled_src in sectors(domain(Vsrc))
7471

75-
fs_src = FusionTreeBlock{I}((cod_uncoupled_src, dom_uncoupled_src), isdual_src)
76-
trees_src = fusiontrees(fs_src)
77-
isempty(trees_src) && continue
72+
nthreads = get_num_transformer_threads()
73+
if nthreads > 1
74+
fusiontreeblocks = Vector{FusionTreeBlock{I,N₁,N₂,fusiontreetype(I, N₁, N₂)}}()
75+
for cod_uncoupled_src in sectors(codomain(Vsrc)),
76+
dom_uncoupled_src in sectors(domain(Vsrc))
77+
78+
fs_src = FusionTreeBlock{I}((cod_uncoupled_src, dom_uncoupled_src), isdual_src)
79+
trees_src = fusiontrees(fs_src)
80+
if !isempty(trees_src)
81+
push!(fusiontreeblocks, fs_src)
82+
end
83+
end
84+
85+
data = Vector{_GenericTransformerData{T,N}}(undef, length(fusiontreeblocks))
86+
counter = Threads.Atomic{Int}(1)
87+
Threads.@sync for _ in 1:min(nthreads, length(fusiontreeblocks))
88+
Threads.@spawn begin
89+
while true
90+
local_counter = Threads.atomic_add!(counter, 1)
91+
local_counter > nblocks && break
92+
fs_src = fusiontreeblocks[local_counter]
93+
fs_dst, U = transform(fs_src)
94+
matrix = copy(transpose(U)) # TODO: should we avoid this
95+
96+
inds_src = map(Base.Fix1(getindex, structure_src.fusiontreeindices),
97+
trees_src)
98+
trees_dst = fusiontrees(fs_dst)
99+
inds_dst = map(Base.Fix1(getindex, structure_dst.fusiontreeindices),
100+
trees_dst)
101+
102+
# size is shared between blocks, so repack:
103+
# from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
104+
sz_src, newstructs_src = repack_transformer_structure(fusionstructure_src,
105+
inds_src)
106+
sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst,
107+
inds_dst)
108+
109+
@debug("Created recoupling block for uncoupled: $uncoupled",
110+
sz = size(matrix),
111+
sparsity = count(!iszero, matrix) / length(matrix))
112+
113+
data[local_counter] = (matrix, (sz_dst, newstructs_dst),
114+
(sz_src, newstructs_src))
115+
end
116+
end
117+
end
118+
else
119+
data = Vector{_GenericTransformerData{T,N}}()
78120

79-
fs_dst, U = transform(fs_src)
80-
matrix = copy(transpose(U)) # TODO: should we avoid this
121+
isdual_src = (map(isdual, codomain(Vsrc).spaces), map(isdual, domain(Vsrc).spaces))
122+
for cod_uncoupled_src in sectors(codomain(Vsrc)),
123+
dom_uncoupled_src in sectors(domain(Vsrc))
81124

82-
inds_src = map(Base.Fix1(getindex, structure_src.fusiontreeindices), trees_src)
83-
trees_dst = fusiontrees(fs_dst)
84-
inds_dst = map(Base.Fix1(getindex, structure_dst.fusiontreeindices), trees_dst)
125+
fs_src = FusionTreeBlock{I}((cod_uncoupled_src, dom_uncoupled_src), isdual_src)
126+
trees_src = fusiontrees(fs_src)
127+
isempty(trees_src) && continue
85128

86-
# size is shared between blocks, so repack:
87-
# from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
88-
sz_src, newstructs_src = repack_transformer_structure(fusionstructure_src, inds_src)
89-
sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst, inds_dst)
129+
fs_dst, U = transform(fs_src)
130+
matrix = copy(transpose(U)) # TODO: should we avoid this
90131

91-
@debug("Created recoupling block for uncoupled: $uncoupled",
92-
sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix))
132+
inds_src = map(Base.Fix1(getindex, structure_src.fusiontreeindices), trees_src)
133+
trees_dst = fusiontrees(fs_dst)
134+
inds_dst = map(Base.Fix1(getindex, structure_dst.fusiontreeindices), trees_dst)
93135

94-
push!(data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src)))
136+
# size is shared between blocks, so repack:
137+
# from [(sz, strides, offset), ...] to (sz, [(strides, offset), ...])
138+
sz_src, newstructs_src = repack_transformer_structure(fusionstructure_src,
139+
inds_src)
140+
sz_dst, newstructs_dst = repack_transformer_structure(fusionstructure_dst,
141+
inds_dst)
142+
143+
@debug("Created recoupling block for uncoupled: $uncoupled",
144+
sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix))
145+
146+
push!(data, (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src)))
147+
end
95148
end
96149

97150
transformer = GenericTreeTransformer{T,N}(data)

0 commit comments

Comments
 (0)