536536function _add_transform_nonthreaded! (tdst, tsrc, p, transformer:: GenericTreeTransformer ,
537537 α, β, backend... )
538538 # preallocate buffers
539- buffersize = maximum (transformer. data) do (_, structures_dst, _)
540- return prod (structures_dst[1 ][1 ])
541- end
542- buffer1 = similar (tsrc. data, buffersize)
543- buffer2 = similar (tdst. data, buffersize)
539+ buffers = allocate_buffers (tdst, tsrc, transformer)
544540
545541 # TODO : this could be multithreaded
546542 for subtransformer in transformer. data
@@ -549,7 +545,7 @@ function _add_transform_nonthreaded!(tdst, tsrc, p, transformer::GenericTreeTran
549545 _add_transform_single! (tdst, tsrc, p, α, β, subtransformer, backend... )
550546 else
551547 _add_transform_multi! (tdst, tsrc, p, α, β, subtransformer,
552- (buffer1, buffer2) , backend... )
548+ buffers , backend... )
553549 end
554550 end
555551 return nothing
@@ -558,17 +554,14 @@ end
558554function _add_transform_threaded! (tdst, tsrc, p, transformer:: GenericTreeTransformer ,
559555 α, β, backend... ;
560556 ntasks:: Int = get_num_transformer_threads ())
561- buffersize = maximum (transformer. data) do (_, structures_dst, _)
562- return prod (structures_dst[1 ][1 ])
563- end
557+ buffersz = buffersize (transformer)
564558 nblocks = length (transformer. data)
565559
566560 counter = Threads. Atomic {Int} (1 )
567561 Threads. @sync for _ in 1 : min (ntasks, nblocks)
568562 Threads. @spawn begin
569563 # preallocate buffers for each task
570- buffer1 = similar (tsrc. data, buffersize)
571- buffer2 = similar (tdst. data, buffersize)
564+ buffers = allocate_buffers (tdst, tsrc, transformer)
572565
573566 while true
574567 local_counter = Threads. atomic_add! (counter, 1 )
@@ -577,7 +570,7 @@ function _add_transform_threaded!(tdst, tsrc, p, transformer::GenericTreeTransfo
577570 if length (subtransformer[1 ]) == 1
578571 _add_transform_single! (tdst, tsrc, p, subtransformer, α, β, backend... )
579572 else
580- _add_transform_multi! (tdst, tsrc, p, subtransformer, (buffer1, buffer2) ,
573+ _add_transform_multi! (tdst, tsrc, p, subtransformer, buffers ,
581574 α, β, backend... )
582575 end
583576 end
@@ -610,14 +603,14 @@ function _add_transform_multi!(tdst, tsrc, p,
610603 blocksize = prod (sz_src)
611604
612605 # Filling up a buffer with contiguous data
613- buffer_src = StridedView (buffer1 , (blocksize, cols), (1 , blocksize), 0 )
606+ buffer_src = StridedView (buffer2 , (blocksize, cols), (1 , blocksize), 0 )
614607 for (i, structure_src) in enumerate (structures_src)
615608 subblock_src = StridedView (tsrc. data, structure_src... )
616609 copyto! (@view (buffer_src[:, i]), subblock_src)
617610 end
618611
619612 # Resummation into a second buffer using BLAS
620- buffer_dst = StridedView (buffer2 , (blocksize, rows), (1 , blocksize), 0 )
613+ buffer_dst = StridedView (buffer1 , (blocksize, rows), (1 , blocksize), 0 )
621614 mul! (buffer_dst, buffer_src, basistransform, α, Zero ())
622615
623616 # Filling up the output
0 commit comments