@@ -534,14 +534,10 @@ function _add_abelian_kernel_threaded!(tdst, tsrc, p, transformer::AbelianTreeTr
534534end
535535
536536function _add_transform_single! (tdst, tsrc, p,
537- (basistransform, structures_dst, structures_src) ,
537+ (coeff, struct_dst, struct_src) :: _AbelianTransformerData ,
538538 α, β, backend... )
539- structure_dst = structures_dst isa Vector ? only (structures_dst) : structures_dst
540- structure_src = structures_src isa Vector ? only (structures_src) : structures_src
541- coeff = basistransform isa Number ? basistransform : only (basistransform)
542-
543- subblock_dst = StridedView (tdst. data, structure_dst... )
544- subblock_src = StridedView (tsrc. data, structure_src... )
539+ subblock_dst = StridedView (tdst. data, struct_dst... )
540+ subblock_src = StridedView (tsrc. data, struct_src... )
545541 TO. tensoradd! (subblock_dst, subblock_src, p, false , α * coeff, β, backend... )
546542 return nothing
547543end
@@ -630,27 +626,40 @@ function _add_general_kernel_nonthreaded!(tdst, tsrc, p, transformer, α, β, ba
630626 return nothing
631627end
632628
629+ function _add_transform_single! (tdst, tsrc, p,
630+ (basistransform, structs_dst,
631+ structs_src):: _GenericTransformerData ,
632+ α, β, backend... )
633+ struct_dst = (structs_dst[1 ], only (structs_dst[2 ])... )
634+ struct_src = (structs_src[1 ], only (structs_src[2 ])... )
635+ coeff = only (basistransform)
636+ _add_transform_single! (tdst, tsrc, p, (coeff, struct_dst, struct_src), α, β, backend... )
637+ return nothing
638+ end
639+
633640function _add_transform_multi! (tdst, tsrc, p,
634- (basistransform, structures_dst, structures_src),
641+ (basistransform, (sz_dst, structs_dst),
642+ (sz_src, structs_src)),
635643 (buffer1, buffer2), α, β, backend... )
636644 rows, cols = size (basistransform)
637- sz_src = first (first (structures_src))
638645 blocksize = prod (sz_src)
646+ matsize = (prod (TupleTools. getindices (sz_src, codomainind (tsrc))),
647+ prod (TupleTools. getindices (sz_src, domainind (tsrc))))
639648
640649 # Filling up a buffer with contiguous data
641650 buffer_src = StridedView (buffer2, (blocksize, cols), (1 , blocksize), 0 )
642- for (i, structure_src ) in enumerate (structures_src )
643- subblock_src = StridedView (tsrc. data, structure_src ... )
644- copy! ( sreshape ( buffer_src[:, i], sz_src) , subblock_src)
651+ for (i, struct_src ) in enumerate (structs_src )
652+ subblock_src = sreshape ( StridedView (tsrc. data, sz_src, struct_src ... ), matsize )
653+ copyto! ( buffer_src[:, i], subblock_src)
645654 end
646655
647656 # Resummation into a second buffer using BLAS
648657 buffer_dst = StridedView (buffer1, (blocksize, rows), (1 , blocksize), 0 )
649658 mul! (buffer_dst, buffer_src, basistransform, α, Zero ())
650659
651660 # Filling up the output
652- for (i, structure_dst ) in enumerate (structures_dst )
653- subblock_dst = StridedView (tdst. data, structure_dst ... )
661+ for (i, struct_dst ) in enumerate (structs_dst )
662+ subblock_dst = StridedView (tdst. data, sz_dst, struct_dst ... )
654663 bufblock_dst = sreshape (buffer_dst[:, i], sz_src)
655664 TO. tensoradd! (subblock_dst, bufblock_dst, p, false , One (), β, backend... )
656665 end
0 commit comments