@@ -54,10 +54,15 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
5454 static_cast <int >(gemm_memory_t ::no_local),
5555 static_cast <int >(gemm_algorithm_t ::standard),
5656 static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 4 ,
57- static_cast <int >(gemm_batch_type_t ::interleaved)>::
58- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
59- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
60- _stridec, batch_size, _dependencies);
57+ static_cast <int >(
58+ gemm_batch_type_t ::interleaved)>::_select_gemm (sb_handle, _M, _N,
59+ _K, _alpha, _a,
60+ _lda, _stridea, _b,
61+ _ldb, _strideb,
62+ _beta, _c, _ldc,
63+ _stridec,
64+ batch_size,
65+ _dependencies);
6166 }
6267#if defined(NAIVE_GEMM)
6368 return blas::Gemm_Launcher<
@@ -66,10 +71,14 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
6671 static_cast <int >(gemm_memory_t ::no_local),
6772 static_cast <int >(gemm_algorithm_t ::naive),
6873 static_cast <int >(gemm_vectorization_t ::partial), is_beta_zero, 1 ,
69- static_cast <int >(gemm_batch_type_t ::strided)>::
70- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
71- _b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
72- batch_size, _dependencies);
74+ static_cast <int >(
75+ gemm_batch_type_t ::strided)>::_select_gemm (sb_handle, _M, _N, _K,
76+ _alpha, _a, _lda,
77+ _stridea, _b, _ldb,
78+ _strideb, _beta, _c,
79+ _ldc, _stridec,
80+ batch_size,
81+ _dependencies);
7382#else
7483 if (_M <= 128 && _N <= 128 && _K <= 256 && !s_a && !s_b) {
7584 return blas::Gemm_Launcher<
@@ -78,43 +87,59 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
7887 static_cast <int >(gemm_memory_t ::no_local),
7988 static_cast <int >(gemm_algorithm_t ::standard),
8089 static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 2 ,
81- static_cast <int >(gemm_batch_type_t ::strided)>::
82- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
83- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
84- _stridec, batch_size, _dependencies);
90+ static_cast <int >(
91+ gemm_batch_type_t ::strided)>::_select_gemm (sb_handle, _M, _N, _K,
92+ _alpha, _a, _lda,
93+ _stridea, _b, _ldb,
94+ _strideb, _beta, _c,
95+ _ldc, _stridec,
96+ batch_size,
97+ _dependencies);
8598 } else if ((_M * _N) >= 524288 && !s_a && !s_b) {
8699 return blas::Gemm_Launcher<
87100 container_0_t , container_1_t , container_2_t , 128 , false , false , false ,
88101 64 , Tile<4 , 4 , 4 , 4 >, _t_a, _t_b, s_a, s_b,
89102 static_cast <int >(gemm_memory_t ::no_local),
90103 static_cast <int >(gemm_algorithm_t ::standard),
91104 static_cast <int >(gemm_vectorization_t ::partial), is_beta_zero, 1 ,
92- static_cast <int >(gemm_batch_type_t ::strided)>::
93- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
94- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
95- _stridec, batch_size, _dependencies);
105+ static_cast <int >(
106+ gemm_batch_type_t ::strided)>::_select_gemm (sb_handle, _M, _N, _K,
107+ _alpha, _a, _lda,
108+ _stridea, _b, _ldb,
109+ _strideb, _beta, _c,
110+ _ldc, _stridec,
111+ batch_size,
112+ _dependencies);
96113 } else if (!s_a && !s_b) {
97114 return blas::Gemm_Launcher<
98115 container_0_t , container_1_t , container_2_t , 128 , false , false , false ,
99116 64 , Tile<4 , 4 , 8 , 8 >, _t_a, _t_b, s_a, s_b,
100117 static_cast <int >(gemm_memory_t ::no_local),
101118 static_cast <int >(gemm_algorithm_t ::standard),
102119 static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 1 ,
103- static_cast <int >(gemm_batch_type_t ::strided)>::
104- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
105- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
106- _stridec, batch_size, _dependencies);
120+ static_cast <int >(
121+ gemm_batch_type_t ::strided)>::_select_gemm (sb_handle, _M, _N, _K,
122+ _alpha, _a, _lda,
123+ _stridea, _b, _ldb,
124+ _strideb, _beta, _c,
125+ _ldc, _stridec,
126+ batch_size,
127+ _dependencies);
107128 } else {
108129 return blas::Gemm_Launcher<
109130 container_0_t , container_1_t , container_2_t , 64 , false , false , false ,
110131 64 , Tile<2 , 2 , 8 , 8 >, _t_a, _t_b, s_a, s_b,
111132 static_cast <int >(gemm_memory_t ::local),
112133 static_cast <int >(gemm_algorithm_t ::standard),
113134 static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 2 ,
114- static_cast <int >(gemm_batch_type_t ::strided)>::
115- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
116- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
117- _stridec, batch_size, _dependencies);
135+ static_cast <int >(
136+ gemm_batch_type_t ::strided)>::_select_gemm (sb_handle, _M, _N, _K,
137+ _alpha, _a, _lda,
138+ _stridea, _b, _ldb,
139+ _strideb, _beta, _c,
140+ _ldc, _stridec,
141+ batch_size,
142+ _dependencies);
118143 }
119144
120145#endif
@@ -145,10 +170,15 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
145170 static_cast <int >(gemm_memory_t ::no_local),
146171 static_cast <int >(gemm_algorithm_t ::standard),
147172 static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 4 ,
148- static_cast <int >(gemm_batch_type_t ::interleaved)>::
149- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda,
150- _stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
151- _stridec, batch_size, _dependencies);
173+ static_cast <int >(
174+ gemm_batch_type_t ::interleaved)>::_select_gemm (sb_handle, _M, _N,
175+ _K, _alpha, _a,
176+ _lda, _stridea, _b,
177+ _ldb, _strideb,
178+ _beta, _c, _ldc,
179+ _stridec,
180+ batch_size,
181+ _dependencies);
152182 }
153183
154184 return blas::Gemm_Launcher<
@@ -157,10 +187,14 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
157187 static_cast <int >(gemm_memory_t ::no_local),
158188 static_cast <int >(gemm_algorithm_t ::standard),
159189 static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 1 ,
160- static_cast <int >(gemm_batch_type_t ::strided)>::
161- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
162- _b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
163- batch_size, _dependencies);
190+ static_cast <int >(
191+ gemm_batch_type_t ::strided)>::_select_gemm (sb_handle, _M, _N, _K,
192+ _alpha, _a, _lda,
193+ _stridea, _b, _ldb,
194+ _strideb, _beta, _c,
195+ _ldc, _stridec,
196+ batch_size,
197+ _dependencies);
164198 }
165199}
166200
@@ -184,21 +218,29 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
184218 static_cast <int >(gemm_memory_t ::no_local),
185219 static_cast <int >(gemm_algorithm_t ::standard),
186220 static_cast <int >(gemm_vectorization_t ::full), is_beta_zero, 1 ,
187- static_cast <int >(gemm_batch_type_t ::strided)>::
188- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
189- _b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
190- batch_size, _dependencies);
221+ static_cast <int >(
222+ gemm_batch_type_t ::strided)>::_select_gemm (sb_handle, _M, _N, _K,
223+ _alpha, _a, _lda,
224+ _stridea, _b, _ldb,
225+ _strideb, _beta, _c,
226+ _ldc, _stridec,
227+ batch_size,
228+ _dependencies);
191229 } else {
192230 return blas::Gemm_Launcher<
193231 container_0_t , container_1_t , container_2_t , 64 , false , false , false ,
194232 64 , Tile<8 , 8 , 4 , 4 >, _t_a, _t_b, false , false ,
195233 static_cast <int >(gemm_memory_t ::no_local),
196234 static_cast <int >(gemm_algorithm_t ::standard),
197235 static_cast <int >(gemm_vectorization_t ::partial), is_beta_zero, 1 ,
198- static_cast <int >(gemm_batch_type_t ::strided)>::
199- template _select_gemm (sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
200- _b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
201- batch_size, _dependencies);
236+ static_cast <int >(
237+ gemm_batch_type_t ::strided)>::_select_gemm (sb_handle, _M, _N, _K,
238+ _alpha, _a, _lda,
239+ _stridea, _b, _ldb,
240+ _strideb, _beta, _c,
241+ _ldc, _stridec,
242+ batch_size,
243+ _dependencies);
202244 }
203245}
204246#endif
0 commit comments