1+ // ===================
2+ // Author: Peize Lin
3+ // date: 2023.08.03
4+ // ===================
5+
6+ #pragma once
7+
8+ #include " Tensor_Multiply.h"
9+
10+ namespace RI
11+ {
12+
13+ namespace Tensor_Multiply
14+ {
15+ // Txy(x1,x2,y1) = Tx(a,x1,x2) * Ty(a,y1)
16+ template <typename Tdata>
17+ Tensor<Tdata> x1x2y1_x0y0 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
18+ {
19+ assert (Tx.shape .size ()==3 );
20+ assert (Ty.shape .size ()==2 );
21+ assert (Tx.shape [0 ]==Ty.shape [0 ]);
22+ const std::size_t x12 = Tx.shape [1 ] * Tx.shape [2 ];
23+ Tensor<Tdata> Txy ({Tx.shape [1 ], Tx.shape [2 ], Ty.shape [1 ]});
24+ Blas_Interface::gemm (
25+ ' T' , ' N' , x12, Ty.shape [1 ], Tx.shape [0 ],
26+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
27+ Tdata (0.0 ), Txy.ptr ());
28+ return Txy;
29+ }
30+
31+ // Txy(x0,x1,y0) = Tx(x0,x1,a) * Ty(y0,a)
32+ template <typename Tdata>
33+ Tensor<Tdata> x0x1y0_x2y1 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
34+ {
35+ assert (Tx.shape .size ()==3 );
36+ assert (Ty.shape .size ()==2 );
37+ assert (Tx.shape [2 ]==Ty.shape [1 ]);
38+ const std::size_t x01 = Tx.shape [0 ] * Tx.shape [1 ];
39+ Tensor<Tdata> Txy ({Tx.shape [0 ], Tx.shape [1 ], Ty.shape [0 ]});
40+ Blas_Interface::gemm (
41+ ' N' , ' T' , x01, Ty.shape [0 ], Tx.shape [2 ],
42+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
43+ Tdata (0.0 ), Txy.ptr ());
44+ return Txy;
45+ }
46+
47+ // Txy(x1,x2,y0) = Tx(a,x1,x2) * Ty(y0,a)
48+ template <typename Tdata>
49+ Tensor<Tdata> x1x2y0_x0y1 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
50+ {
51+ assert (Tx.shape .size ()==3 );
52+ assert (Ty.shape .size ()==2 );
53+ const std::size_t x12 = Tx.shape [1 ] * Tx.shape [2 ];
54+ Tensor<Tdata> Txy ({Tx.shape [1 ], Tx.shape [2 ], Ty.shape [0 ]});
55+ Blas_Interface::gemm (
56+ ' T' , ' T' , x12, Ty.shape [0 ], Tx.shape [0 ],
57+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
58+ Tdata (0.0 ), Txy.ptr ());
59+ return Txy;
60+ }
61+
62+ // Txy(x0,x1,y1) = Tx(x0,x1,a) * Ty(a,y1)
63+ template <typename Tdata>
64+ Tensor<Tdata> x0x1y1_x2y0 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
65+ {
66+ assert (Tx.shape .size ()==3 );
67+ assert (Ty.shape .size ()==2 );
68+ const std::size_t x01 = Tx.shape [0 ] * Tx.shape [1 ];
69+ Tensor<Tdata> Txy ({Tx.shape [0 ], Tx.shape [1 ], Ty.shape [1 ]});
70+ Blas_Interface::gemm (
71+ ' N' , ' N' , x01, Ty.shape [1 ], Tx.shape [2 ],
72+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
73+ Tdata (0.0 ), Txy.ptr ());
74+ return Txy;
75+ }
76+ }
77+
78+ }
0 commit comments