1- // ===================
2- // Author: Peize Lin
3- // date: 2023.08.03
4- // ===================
5-
6- #pragma once
7-
8- #include " Tensor_Multiply.h"
9-
10- namespace RI
11- {
12-
13- namespace Tensor_Multiply
14- {
15- // Txy(x1,x2,y1) = Tx(a,x1,x2) * Ty(a,y1)
16- template <typename Tdata>
17- Tensor<Tdata> x1x2y1_x0y0 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
18- {
19- assert (Tx.shape .size ()==3 );
20- assert (Ty.shape .size ()==2 );
21- assert (Tx.shape [0 ]==Ty.shape [0 ]);
22- const std::size_t x12 = Tx.shape [1 ] * Tx.shape [2 ];
23- Tensor<Tdata> Txy ({Tx.shape [1 ], Tx.shape [2 ], Ty.shape [1 ]});
24- Blas_Interface::gemm (
25- ' T' , ' N' , x12, Ty.shape [1 ], Tx.shape [0 ],
26- Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
27- Tdata (0.0 ), Txy.ptr ());
28- return Txy;
29- }
30-
31- // Txy(x0,x1,y0) = Tx(x0,x1,a) * Ty(y0,a)
32- template <typename Tdata>
33- Tensor<Tdata> x0x1y0_x2y1 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
34- {
35- assert (Tx.shape .size ()==3 );
36- assert (Ty.shape .size ()==2 );
37- assert (Tx.shape [2 ]==Ty.shape [1 ]);
38- const std::size_t x01 = Tx.shape [0 ] * Tx.shape [1 ];
39- Tensor<Tdata> Txy ({Tx.shape [0 ], Tx.shape [1 ], Ty.shape [0 ]});
40- Blas_Interface::gemm (
41- ' N' , ' T' , x01, Ty.shape [0 ], Tx.shape [2 ],
42- Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
43- Tdata (0.0 ), Txy.ptr ());
44- return Txy;
45- }
46-
47- // Txy(x1,x2,y0) = Tx(a,x1,x2) * Ty(y0,a)
48- template <typename Tdata>
49- Tensor<Tdata> x1x2y0_x0y1 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
50- {
51- assert (Tx.shape .size ()==3 );
52- assert (Ty.shape .size ()==2 );
53- const std::size_t x12 = Tx.shape [1 ] * Tx.shape [2 ];
54- Tensor<Tdata> Txy ({Tx.shape [1 ], Tx.shape [2 ], Ty.shape [0 ]});
55- Blas_Interface::gemm (
56- ' T' , ' T' , x12, Ty.shape [0 ], Tx.shape [0 ],
57- Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
58- Tdata (0.0 ), Txy.ptr ());
59- return Txy;
60- }
61-
62- // Txy(x0,x1,y1) = Tx(x0,x1,a) * Ty(a,y1)
63- template <typename Tdata>
64- Tensor<Tdata> x0x1y1_x2y0 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
65- {
66- assert (Tx.shape .size ()==3 );
67- assert (Ty.shape .size ()==2 );
68- const std::size_t x01 = Tx.shape [0 ] * Tx.shape [1 ];
69- Tensor<Tdata> Txy ({Tx.shape [0 ], Tx.shape [1 ], Ty.shape [1 ]});
70- Blas_Interface::gemm (
71- ' N' , ' N' , x01, Ty.shape [1 ], Tx.shape [2 ],
72- Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
73- Tdata (0.0 ), Txy.ptr ());
74- return Txy;
75- }
76- }
77-
78- }
1+ // ===================
2+ // Author: Peize Lin
3+ // date: 2023.08.03
4+ // ===================
5+
6+ #pragma once
7+
8+ #include " Tensor_Multiply.h"
9+
10+ namespace RI
11+ {
12+
13+ namespace Tensor_Multiply
14+ {
15+ // Txy(x1,x2,y1) = Tx(a,x1,x2) * Ty(a,y1)
16+ template <typename Tdata>
17+ Tensor<Tdata> x1x2y1_ax1x2_ay1 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
18+ {
19+ assert (Tx.shape .size ()==3 );
20+ assert (Ty.shape .size ()==2 );
21+ Tensor<Tdata> Txy ({Tx.shape [1 ], Tx.shape [2 ], Ty.shape [1 ]});
22+ Blas_Interface::gemm (
23+ ' T' , ' N' ,
24+ Tx.shape [1 ] * Tx.shape [2 ],
25+ Ty.shape [1 ],
26+ Tx.shape [0 ],
27+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
28+ Tdata (0.0 ), Txy.ptr ());
29+ return Txy;
30+ }
31+
32+ // Txy(x0,x1,y0) = Tx(x0,x1,a) * Ty(y0,a)
33+ template <typename Tdata>
34+ Tensor<Tdata> x0x1y0_x0x1a_y0a (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
35+ {
36+ assert (Tx.shape .size ()==3 );
37+ assert (Ty.shape .size ()==2 );
38+ Tensor<Tdata> Txy ({Tx.shape [0 ], Tx.shape [1 ], Ty.shape [0 ]});
39+ Blas_Interface::gemm (
40+ ' N' , ' T' ,
41+ Tx.shape [0 ] * Tx.shape [1 ],
42+ Ty.shape [0 ],
43+ Tx.shape [2 ],
44+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
45+ Tdata (0.0 ), Txy.ptr ());
46+ return Txy;
47+ }
48+
49+ // Txy(x1,x2,y0) = Tx(a,x1,x2) * Ty(y0,a)
50+ template <typename Tdata>
51+ Tensor<Tdata> x1x2y0_ax1x2_y0a (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
52+ {
53+ assert (Tx.shape .size ()==3 );
54+ assert (Ty.shape .size ()==2 );
55+ Tensor<Tdata> Txy ({Tx.shape [1 ], Tx.shape [2 ], Ty.shape [0 ]});
56+ Blas_Interface::gemm (
57+ ' T' , ' T' ,
58+ Tx.shape [1 ] * Tx.shape [2 ],
59+ Ty.shape [0 ],
60+ Tx.shape [0 ],
61+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
62+ Tdata (0.0 ), Txy.ptr ());
63+ return Txy;
64+ }
65+
66+ // Txy(x0,x1,y1) = Tx(x0,x1,a) * Ty(a,y1)
67+ template <typename Tdata>
68+ Tensor<Tdata> x0x1y1_x0x1a_ay1 (const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
69+ {
70+ assert (Tx.shape .size ()==3 );
71+ assert (Ty.shape .size ()==2 );
72+ Tensor<Tdata> Txy ({Tx.shape [0 ], Tx.shape [1 ], Ty.shape [1 ]});
73+ Blas_Interface::gemm (
74+ ' N' , ' N' ,
75+ Tx.shape [0 ] * Tx.shape [1 ],
76+ Ty.shape [1 ],
77+ Tx.shape [2 ],
78+ Tdata (1.0 ), Tx.ptr (), Ty.ptr (),
79+ Tdata (0.0 ), Txy.ptr ());
80+ return Txy;
81+ }
82+
83+ }
84+
85+ }
0 commit comments