7373function treebraider (:: AbstractTensorMap , :: AbstractTensorMap , p:: Index2Tuple , levels)
7474 return fusiontreetransform (f1, f2) = braid (f1, f2, levels... , p... )
7575end
76- @cached function treebraider (tdst:: TensorMap , tsrc:: TensorMap , p:: Index2Tuple ,
77- levels):: treetransformertype (space (tdst),
78- space (tsrc))
76+ function treebraider (tdst:: TensorMap , tsrc:: TensorMap , p:: Index2Tuple , levels)
77+ return treebraider (space (tdst), space (tsrc), p, levels)
78+ end
79+ @cached function treebraider (Vdst:: TensorMapSpace , Vsrc:: TensorMapSpace , p:: Index2Tuple ,
80+ levels):: treetransformertype (Vdst, Vsrc)
7981 fusiontreebraider (f1, f2) = braid (f1, f2, levels... , p... )
80- return TreeTransformer (fusiontreebraider, space (tsrc), space (tdst) )
82+ return TreeTransformer (fusiontreebraider, Vdst, Vsrc )
8183end
8284
8385for (transform, treetransformer) in
@@ -86,11 +88,13 @@ for (transform, treetransformer) in
8688 function $treetransformer (:: AbstractTensorMap , :: AbstractTensorMap , p:: Index2Tuple )
8789 return fusiontreetransform (f1, f2) = $ transform (f1, f2, p... )
8890 end
89- @cached function $treetransformer (tdst:: TensorMap , tsrc:: TensorMap ,
90- p:: Index2Tuple ):: treetransformertype (space (tdst),
91- space (tsrc))
91+ function $treetransformer (tdst:: TensorMap , tsrc:: TensorMap , p:: Index2Tuple )
92+ return $ treetransformer (space (tdst), space (tsrc), p)
93+ end
94+ @cached function $treetransformer (Vdst:: TensorMapSpace , Vsrc:: TensorMapSpace ,
95+ p:: Index2Tuple ):: treetransformertype (Vdst, Vsrc)
9296 fusiontreetransform (f1, f2) = $ transform (f1, f2, p... )
93- return TreeTransformer (fusiontreetransform, space (tsrc), space (tdst) )
97+ return TreeTransformer (fusiontreetransform, Vdst, Vsrc )
9498 end
9599 end
96100end
0 commit comments