@@ -59,6 +59,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
59
59
const std::int64_t ,
60
60
char *,
61
61
const std::int64_t ,
62
+ bool ,
62
63
const std::vector<sycl::event> &);
63
64
64
65
static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
@@ -77,6 +78,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
77
78
const std::int64_t ldb,
78
79
char *resultC,
79
80
const std::int64_t ldc,
81
+ bool is_row_major,
80
82
const std::vector<sycl::event> &depends)
81
83
{
82
84
type_utils::validate_type_for_device<Tab>(exec_q);
@@ -91,7 +93,25 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
91
93
92
94
sycl::event gemm_event;
93
95
try {
94
- gemm_event = mkl_blas::row_major::gemm (
96
+ auto gemm_func =
97
+ [&](sycl::queue &q, oneapi::mkl::transpose transA,
98
+ oneapi::mkl::transpose transB, std::int64_t m, std::int64_t n,
99
+ std::int64_t k, Tab alpha, const Tab *a, std::int64_t lda,
100
+ const Tab *b, std::int64_t ldb, Tab beta, Tc *c,
101
+ std::int64_t ldc,
102
+ const std::vector<sycl::event> &deps) -> sycl::event {
103
+ if (is_row_major) {
104
+ return mkl_blas::row_major::gemm (q, transA, transB, m, n, k,
105
+ alpha, a, lda, b, ldb, beta, c,
106
+ ldc, deps);
107
+ }
108
+ else {
109
+ return mkl_blas::column_major::gemm (q, transA, transB, m, n, k,
110
+ alpha, a, lda, b, ldb, beta,
111
+ c, ldc, deps);
112
+ }
113
+ };
114
+ gemm_event = gemm_func (
95
115
exec_q,
96
116
transA, // Defines the transpose operation for matrix A:
97
117
// 'N' indicates no transpose, 'T' for transpose,
@@ -130,7 +150,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
130
150
return gemm_event;
131
151
}
132
152
133
- std::pair <sycl::event, sycl::event>
153
+ std::tuple <sycl::event, sycl::event, bool >
134
154
gemm (sycl::queue &exec_q,
135
155
dpctl::tensor::usm_ndarray matrixA,
136
156
dpctl::tensor::usm_ndarray matrixB,
@@ -208,16 +228,44 @@ std::pair<sycl::event, sycl::event>
208
228
throw py::value_error (
209
229
" Result array is not c-contiguous nor f-contiguous." );
210
230
}
211
- oneapi::mkl::transpose transA = is_matrixA_f_contig
212
- ? oneapi::mkl::transpose::T
213
- : oneapi::mkl::transpose::N;
214
- oneapi::mkl::transpose transB = is_matrixB_f_contig
215
- ? oneapi::mkl::transpose::T
216
- : oneapi::mkl::transpose::N;
231
+ bool is_row_major = true ;
232
+ if (is_matrixA_f_contig && is_matrixB_f_contig) {
233
+ is_row_major = false ;
234
+ }
235
+ oneapi::mkl::transpose transA;
236
+ oneapi::mkl::transpose transB;
237
+ if (is_row_major) {
238
+ transA = is_matrixA_f_contig ? oneapi::mkl::transpose::T
239
+ : oneapi::mkl::transpose::N;
240
+ transB = is_matrixB_f_contig ? oneapi::mkl::transpose::T
241
+ : oneapi::mkl::transpose::N;
242
+ }
243
+ else {
244
+ transA = oneapi::mkl::transpose::N;
245
+ transB = oneapi::mkl::transpose::N;
246
+ }
217
247
218
- const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m;
219
- const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k;
220
- const std::int64_t ldc = n; // always n for row_major
248
+ std::int64_t lda;
249
+ std::int64_t ldb;
250
+ if (is_row_major) {
251
+ if (transA == oneapi::mkl::transpose::N) {
252
+ lda = k;
253
+ }
254
+ else {
255
+ lda = m;
256
+ }
257
+ if (transB == oneapi::mkl::transpose::N) {
258
+ ldb = n;
259
+ }
260
+ else {
261
+ ldb = k;
262
+ }
263
+ }
264
+ else {
265
+ lda = m;
266
+ ldb = k;
267
+ }
268
+ const std::int64_t ldc = is_row_major ? n : m;
221
269
222
270
int matrixA_typenum = matrixA.get_typenum ();
223
271
int matrixB_typenum = matrixB.get_typenum ();
@@ -242,14 +290,14 @@ std::pair<sycl::event, sycl::event>
242
290
char *b_typeless_ptr = matrixB.get_data ();
243
291
char *r_typeless_ptr = resultC.get_data ();
244
292
245
- sycl::event gemm_ev =
246
- gemm_fn (exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
247
- b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);
293
+ sycl::event gemm_ev = gemm_fn (exec_q, transA, transB, m, n, k,
294
+ a_typeless_ptr, lda, b_typeless_ptr, ldb ,
295
+ r_typeless_ptr, ldc, is_row_major , depends);
248
296
249
297
sycl::event args_ev = dpctl::utils::keep_args_alive (
250
298
exec_q, {matrixA, matrixB, resultC}, {gemm_ev});
251
299
252
- return std::make_pair (args_ev, gemm_ev);
300
+ return std::make_tuple (args_ev, gemm_ev, is_row_major );
253
301
}
254
302
255
303
template <typename fnT, typename Tab, typename Tc>
0 commit comments