Skip to content

Commit e8d6287

Browse files
committed
Refactor buffer allocation
1 parent 4c39252 commit e8d6287

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

src/tensors/indexmanipulations.jl

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -536,11 +536,7 @@ end
536536
function _add_transform_nonthreaded!(tdst, tsrc, p, transformer::GenericTreeTransformer,
537537
α, β, backend...)
538538
# preallocate buffers
539-
buffersize = maximum(transformer.data) do (_, structures_dst, _)
540-
return prod(structures_dst[1][1])
541-
end
542-
buffer1 = similar(tsrc.data, buffersize)
543-
buffer2 = similar(tdst.data, buffersize)
539+
buffers = allocate_buffers(tdst, tsrc, transformer)
544540

545541
# TODO: this could be multithreaded
546542
for subtransformer in transformer.data
@@ -549,7 +545,7 @@ function _add_transform_nonthreaded!(tdst, tsrc, p, transformer::GenericTreeTran
549545
_add_transform_single!(tdst, tsrc, p, α, β, subtransformer, backend...)
550546
else
551547
_add_transform_multi!(tdst, tsrc, p, α, β, subtransformer,
552-
(buffer1, buffer2), backend...)
548+
buffers, backend...)
553549
end
554550
end
555551
return nothing
@@ -558,17 +554,14 @@ end
558554
function _add_transform_threaded!(tdst, tsrc, p, transformer::GenericTreeTransformer,
559555
α, β, backend...;
560556
ntasks::Int=get_num_transformer_threads())
561-
buffersize = maximum(transformer.data) do (_, structures_dst, _)
562-
return prod(structures_dst[1][1])
563-
end
557+
buffersz = buffersize(transformer)
564558
nblocks = length(transformer.data)
565559

566560
counter = Threads.Atomic{Int}(1)
567561
Threads.@sync for _ in 1:min(ntasks, nblocks)
568562
Threads.@spawn begin
569563
# preallocate buffers for each task
570-
buffer1 = similar(tsrc.data, buffersize)
571-
buffer2 = similar(tdst.data, buffersize)
564+
buffers = allocate_buffers(tdst, tsrc, transformer)
572565

573566
while true
574567
local_counter = Threads.atomic_add!(counter, 1)
@@ -577,7 +570,7 @@ function _add_transform_threaded!(tdst, tsrc, p, transformer::GenericTreeTransfo
577570
if length(subtransformer[1]) == 1
578571
_add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
579572
else
580-
_add_transform_multi!(tdst, tsrc, p, subtransformer, (buffer1, buffer2),
573+
_add_transform_multi!(tdst, tsrc, p, subtransformer, buffers,
581574
α, β, backend...)
582575
end
583576
end
@@ -610,14 +603,14 @@ function _add_transform_multi!(tdst, tsrc, p,
610603
blocksize = prod(sz_src)
611604

612605
# Filling up a buffer with contiguous data
613-
buffer_src = StridedView(buffer1, (blocksize, cols), (1, blocksize), 0)
606+
buffer_src = StridedView(buffer2, (blocksize, cols), (1, blocksize), 0)
614607
for (i, structure_src) in enumerate(structures_src)
615608
subblock_src = StridedView(tsrc.data, structure_src...)
616609
copyto!(@view(buffer_src[:, i]), subblock_src)
617610
end
618611

619612
# Resummation into a second buffer using BLAS
620-
buffer_dst = StridedView(buffer2, (blocksize, rows), (1, blocksize), 0)
613+
buffer_dst = StridedView(buffer1, (blocksize, rows), (1, blocksize), 0)
621614
mul!(buffer_dst, buffer_src, basistransform, α, Zero())
622615

623616
# Filling up the output

src/tensors/treetransformers.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@ function _transformer_weight((matrix, structures_dst, structures_src))
9292
return size(matrix, 1) * prod(structures_dst[1][1])
9393
end
9494

95+
function buffersize(transformer::GenericTreeTransformer)
96+
return maximum(transformer.data) do (basistransform, structures_dst, _)
97+
return prod(structures_dst[1][1]) * max(size(basistransform)...)
98+
end
99+
end
100+
101+
function allocate_buffers(tdst::TensorMap, tsrc::TensorMap,
102+
transformer::GenericTreeTransformer)
103+
sz = buffersize(transformer)
104+
return similar(tdst.data, sz), similar(tsrc.data, sz)
105+
end
106+
95107
function treetransformertype(Vdst, Vsrc)
96108
I = sectortype(Vdst)
97109
I === Trivial && return TrivialTreeTransformer

0 commit comments

Comments
 (0)