Skip to content

Commit 70caeb9

Browse files
committed
Avoid duplicate storage of sizes
1 parent 071a6b6 commit 70caeb9

File tree

2 files changed

+40
-21
lines changed

2 files changed

+40
-21
lines changed

src/tensors/indexmanipulations.jl

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -534,14 +534,10 @@ function _add_abelian_kernel_threaded!(tdst, tsrc, p, transformer::AbelianTreeTr
534534
end
535535

536536
function _add_transform_single!(tdst, tsrc, p,
537-
(basistransform, structures_dst, structures_src),
537+
(coeff, struct_dst, struct_src)::_AbelianTransformerData,
538538
α, β, backend...)
539-
structure_dst = structures_dst isa Vector ? only(structures_dst) : structures_dst
540-
structure_src = structures_src isa Vector ? only(structures_src) : structures_src
541-
coeff = basistransform isa Number ? basistransform : only(basistransform)
542-
543-
subblock_dst = StridedView(tdst.data, structure_dst...)
544-
subblock_src = StridedView(tsrc.data, structure_src...)
539+
subblock_dst = StridedView(tdst.data, struct_dst...)
540+
subblock_src = StridedView(tsrc.data, struct_src...)
545541
TO.tensoradd!(subblock_dst, subblock_src, p, false, α * coeff, β, backend...)
546542
return nothing
547543
end
@@ -630,27 +626,40 @@ function _add_general_kernel_nonthreaded!(tdst, tsrc, p, transformer, α, β, ba
630626
return nothing
631627
end
632628

629+
function _add_transform_single!(tdst, tsrc, p,
630+
(basistransform, structs_dst,
631+
structs_src)::_GenericTransformerData,
632+
α, β, backend...)
633+
struct_dst = (structs_dst[1], only(structs_dst[2])...)
634+
struct_src = (structs_src[1], only(structs_src[2])...)
635+
coeff = only(basistransform)
636+
_add_transform_single!(tdst, tsrc, p, (coeff, struct_dst, struct_src), α, β, backend...)
637+
return nothing
638+
end
639+
633640
function _add_transform_multi!(tdst, tsrc, p,
634-
(basistransform, structures_dst, structures_src),
641+
(basistransform, (sz_dst, structs_dst),
642+
(sz_src, structs_src)),
635643
(buffer1, buffer2), α, β, backend...)
636644
rows, cols = size(basistransform)
637-
sz_src = first(first(structures_src))
638645
blocksize = prod(sz_src)
646+
matsize = (prod(TupleTools.getindices(sz_src, codomainind(tsrc))),
647+
prod(TupleTools.getindices(sz_src, domainind(tsrc))))
639648

640649
# Filling up a buffer with contiguous data
641650
buffer_src = StridedView(buffer2, (blocksize, cols), (1, blocksize), 0)
642-
for (i, structure_src) in enumerate(structures_src)
643-
subblock_src = StridedView(tsrc.data, structure_src...)
644-
copy!(sreshape(buffer_src[:, i], sz_src), subblock_src)
651+
for (i, struct_src) in enumerate(structs_src)
652+
subblock_src = sreshape(StridedView(tsrc.data, sz_src, struct_src...), matsize)
653+
copyto!(buffer_src[:, i], subblock_src)
645654
end
646655

647656
# Resummation into a second buffer using BLAS
648657
buffer_dst = StridedView(buffer1, (blocksize, rows), (1, blocksize), 0)
649658
mul!(buffer_dst, buffer_src, basistransform, α, Zero())
650659

651660
# Filling up the output
652-
for (i, structure_dst) in enumerate(structures_dst)
653-
subblock_dst = StridedView(tdst.data, structure_dst...)
661+
for (i, struct_dst) in enumerate(structs_dst)
662+
subblock_dst = StridedView(tdst.data, sz_dst, struct_dst...)
654663
bufblock_dst = sreshape(buffer_dst[:, i], sz_src)
655664
TO.tensoradd!(subblock_dst, bufblock_dst, p, false, One(), β, backend...)
656665
end

src/tensors/treetransformers.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,11 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc)
4545
return transformer
4646
end
4747

48-
const _GenericTransformerData{T,N} = Tuple{Matrix{T},Vector{StridedStructure{N}},
49-
Vector{StridedStructure{N}}}
48+
const _GenericTransformerData{T,N} = Tuple{Matrix{T},
49+
Tuple{NTuple{N,Int},
50+
Vector{Tuple{NTuple{N,Int},Int}}},
51+
Tuple{NTuple{N,Int},
52+
Vector{Tuple{NTuple{N,Int},Int}}}}
5053

5154
struct GenericTreeTransformer{T,N} <: TreeTransformer
5255
data::Vector{_GenericTransformerData{T,N}}
@@ -90,12 +93,19 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc)
9093
matrix[row, col] = coeff
9194
end
9295
end
96+
97+
structs_src = structure_src.fusiontreestructure[ids_src]
98+
sz_src = structs_src[1][1]
99+
newstructs_src = map(x -> (x[2], x[3]), structs_src)
100+
101+
structs_dst = structure_dst.fusiontreestructure[ids_dst]
102+
sz_dst = structs_dst[1][1]
103+
newstructs_dst = map(x -> (x[2], x[3]), structs_dst)
104+
93105
@debug("Created recoupling block for uncoupled: $uncoupled",
94106
sz = size(matrix), sparsity = count(!iszero, matrix) / length(matrix))
95107

96-
data[i] = (matrix,
97-
structure_dst.fusiontreestructure[ids_dst],
98-
structure_src.fusiontreestructure[ids_src])
108+
data[i] = (matrix, (sz_dst, newstructs_dst), (sz_src, newstructs_src))
99109
end
100110

101111
transformer = GenericTreeTransformer{T,N}(data)
@@ -116,7 +126,7 @@ end
116126

117127
function buffersize(transformer::GenericTreeTransformer)
118128
return maximum(transformer.data; init=0) do (basistransform, structures_dst, _)
119-
return prod(structures_dst[1][1]) * size(basistransform, 1)
129+
return prod(structures_dst[1]) * size(basistransform, 1)
120130
end
121131
end
122132

@@ -199,5 +209,5 @@ end
199209
# Note that it might be the case that the permutations are dominant, in which case the
200210
# actual cost model would scale like L x length(subblock)
201211
function _transformer_weight((mat, structs_dst, structs_src)::_GenericTransformerData)
202-
return length(mat) * prod(structs_dst[1][1])
212+
return length(mat) * prod(structs_dst[1])
203213
end

0 commit comments

Comments
 (0)