@@ -12,6 +12,23 @@ namespace RI
1212
1313namespace Tensor_Multiply
1414{
15+ // Txy(x0,y0) = Tx(x0,a,b) * Ty(y0,a,b)
16+ template <typename Tdata>
17+ Tensor<Tdata> x0y0_x0ab_y0ab (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
18+ {
19+ assert (Tx.shape .size ()==3 );
20+ assert (Ty.shape .size ()==3 );
21+ Tensor<Tdata> Txy ({Tx.shape [0 ], Ty.shape [0 ]});
22+ Blas_Interface::gemm (
23+ ' N' , ' T' ,
24+ Tx.shape [0 ],
25+ Ty.shape [0 ],
26+ Tx.shape [1 ] * Tx.shape [2 ],
27+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
28+ Tdata (0.0 ), Txy.ptr ());
29+ return Txy;
30+ }
31+
1532 // Txy(x0,y2) = Tx(x0,a,b) * Ty(a,b,y2)
1633 template <typename Tdata>
1734 Tensor<Tdata> x0y2_x0ab_aby2 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
@@ -63,6 +80,74 @@ namespace Tensor_Multiply
6380 return Txy;
6481 }
6582
83+ // Txy(x0,x1,y0,y1) = Tx(x0,x1,a) * Ty(y0,y1,a)
84+ template <typename Tdata>
85+ Tensor<Tdata> x0x1y0y1_x0x1a_y0y1a (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
86+ {
87+ assert (Tx.shape .size ()==3 );
88+ assert (Ty.shape .size ()==3 );
89+ Tensor<Tdata> Txy ({Tx.shape [0 ], Tx.shape [1 ], Ty.shape [0 ], Ty.shape [1 ]});
90+ Blas_Interface::gemm (
91+ ' N' , ' T' ,
92+ Tx.shape [0 ] * Tx.shape [1 ],
93+ Ty.shape [0 ] * Ty.shape [1 ],
94+ Tx.shape [2 ],
95+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
96+ Tdata (0.0 ), Txy.ptr ());
97+ return Txy;
98+ }
99+
100+ // Txy(x0,x1,y1,y2) = Tx(x0,x1,a) * Ty(a,y1,y2)
101+ template <typename Tdata>
102+ Tensor<Tdata> x0x1y1y2_x0x1a_ay1y2 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
103+ {
104+ assert (Tx.shape .size ()==3 );
105+ assert (Ty.shape .size ()==3 );
106+ Tensor<Tdata> Txy ({Tx.shape [0 ], Tx.shape [1 ], Ty.shape [1 ], Ty.shape [2 ]});
107+ Blas_Interface::gemm (
108+ ' N' , ' N' ,
109+ Tx.shape [0 ] * Tx.shape [1 ],
110+ Ty.shape [1 ] * Ty.shape [2 ],
111+ Tx.shape [2 ],
112+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
113+ Tdata (0.0 ), Txy.ptr ());
114+ return Txy;
115+ }
116+
117+ // Txy(x1,x2,y0,y1) = Tx(a,x1,x2) * Ty(y0,y1,a)
118+ template <typename Tdata>
119+ Tensor<Tdata> x1x2y0y1_ax1x2_y0y1a (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
120+ {
121+ assert (Tx.shape .size ()==3 );
122+ assert (Ty.shape .size ()==3 );
123+ Tensor<Tdata> Txy ({Tx.shape [1 ], Tx.shape [2 ], Ty.shape [0 ], Ty.shape [1 ]});
124+ Blas_Interface::gemm (
125+ ' T' , ' T' ,
126+ Tx.shape [1 ] * Tx.shape [2 ],
127+ Ty.shape [0 ] * Ty.shape [1 ],
128+ Tx.shape [0 ],
129+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
130+ Tdata (0.0 ), Txy.ptr ());
131+ return Txy;
132+ }
133+
134+ // Txy(x1,x2,y1,y2) = Tx(a,x1,x2) * Ty(a,y1,y2)
135+ template <typename Tdata>
136+ Tensor<Tdata> x1x2y1y2_ax1x2_ay1y2 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
137+ {
138+ assert (Tx.shape .size ()==3 );
139+ assert (Ty.shape .size ()==3 );
140+ Tensor<Tdata> Txy ({Tx.shape [1 ], Tx.shape [2 ], Ty.shape [1 ], Ty.shape [2 ]});
141+ Blas_Interface::gemm (
142+ ' T' , ' N' ,
143+ Tx.shape [1 ] * Tx.shape [2 ],
144+ Ty.shape [1 ] * Ty.shape [2 ],
145+ Tx.shape [0 ],
146+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
147+ Tdata (0.0 ), Txy.ptr ());
148+ return Txy;
149+ }
150+
66151}
67152
68153}
0 commit comments