Skip to content

Commit 65141f2

Browse files
committed
Change cache key to space instead of tensor
1 parent 6bf12af commit 65141f2

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

src/tensors/treetransformers.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,13 @@ end
7373
function treebraider(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple, levels)
7474
return fusiontreetransform(f1, f2) = braid(f1, f2, levels..., p...)
7575
end
76-
@cached function treebraider(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple,
77-
levels)::treetransformertype(space(tdst),
78-
space(tsrc))
76+
function treebraider(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple, levels)
77+
return treebraider(space(tdst), space(tsrc), p, levels)
78+
end
79+
@cached function treebraider(Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple,
80+
levels)::treetransformertype(Vdst, Vsrc)
7981
fusiontreebraider(f1, f2) = braid(f1, f2, levels..., p...)
80-
return TreeTransformer(fusiontreebraider, space(tsrc), space(tdst))
82+
return TreeTransformer(fusiontreebraider, Vdst, Vsrc)
8183
end
8284

8385
for (transform, treetransformer) in
@@ -86,11 +88,13 @@ for (transform, treetransformer) in
8688
function $treetransformer(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple)
8789
return fusiontreetransform(f1, f2) = $transform(f1, f2, p...)
8890
end
89-
@cached function $treetransformer(tdst::TensorMap, tsrc::TensorMap,
90-
p::Index2Tuple)::treetransformertype(space(tdst),
91-
space(tsrc))
91+
function $treetransformer(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple)
92+
return $treetransformer(space(tdst), space(tsrc), p)
93+
end
94+
@cached function $treetransformer(Vdst::TensorMapSpace, Vsrc::TensorMapSpace,
95+
p::Index2Tuple)::treetransformertype(Vdst, Vsrc)
9296
fusiontreetransform(f1, f2) = $transform(f1, f2, p...)
93-
return TreeTransformer(fusiontreetransform, space(tsrc), space(tdst))
97+
return TreeTransformer(fusiontreetransform, Vdst, Vsrc)
9498
end
9599
end
96100
end

0 commit comments

Comments
 (0)