Skip to content

Commit 4c4b5c5

Browse files
vtavana and antonwolfy authored
update OneMKL gemm_batch call inside dpnp.matmul and column_major version of gemm (#1793)
* updating gemm_batch call
* fix failed tests for dot function
* address comments
* split batch_size and column major: split batch_size into smaller chunks and use gemm with column major when both input arrays are F-contiguous
* fix failed tests
* simplify gemm implementation
* waiting for host task before changing results
* address comments
* reorder functions in dpnp/dpnp_utils/dpnp_utils_linearalgebra.py to keep their alphabetical order

---------

Co-authored-by: Anton <[email protected]>
1 parent 9f072b6 commit 4c4b5c5

File tree

6 files changed

+589
-397
lines changed

6 files changed

+589
-397
lines changed

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ PYBIND11_MODULE(_blas_impl, m)
6464
blas_ext::DotContigFactory>(
6565
dot_dispatch_vector);
6666

67-
auto dot_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
68-
arrayT dst, const event_vecT &depends = {}) {
67+
auto dot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
68+
arrayT dst, const event_vecT &depends = {}) {
6969
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
7070
dot_dispatch_vector);
7171
};
7272

73-
m.def("_dot", dot_pypi,
73+
m.def("_dot", dot_pyapi,
7474
"Call `dot` from OneMKL BLAS library to return "
7575
"the dot product of two real-valued vectors.",
7676
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -82,13 +82,13 @@ PYBIND11_MODULE(_blas_impl, m)
8282
blas_ext::DotcContigFactory>(
8383
dotc_dispatch_vector);
8484

85-
auto dotc_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
86-
arrayT dst, const event_vecT &depends = {}) {
85+
auto dotc_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
86+
arrayT dst, const event_vecT &depends = {}) {
8787
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
8888
dotc_dispatch_vector);
8989
};
9090

91-
m.def("_dotc", dotc_pypi,
91+
m.def("_dotc", dotc_pyapi,
9292
"Call `dotc` from OneMKL BLAS library to return "
9393
"the dot product of two complex vectors, "
9494
"conjugating the first vector.",
@@ -101,13 +101,13 @@ PYBIND11_MODULE(_blas_impl, m)
101101
blas_ext::DotuContigFactory>(
102102
dotu_dispatch_vector);
103103

104-
auto dotu_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
105-
arrayT dst, const event_vecT &depends = {}) {
104+
auto dotu_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
105+
arrayT dst, const event_vecT &depends = {}) {
106106
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
107107
dotu_dispatch_vector);
108108
};
109109

110-
m.def("_dotu", dotu_pypi,
110+
m.def("_dotu", dotu_pyapi,
111111
"Call `dotu` from OneMKL BLAS library to return "
112112
"the dot product of two complex vectors.",
113113
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -119,16 +119,14 @@ PYBIND11_MODULE(_blas_impl, m)
119119
"Call `gemm` from OneMKL BLAS library to return "
120120
"the matrix-matrix product with 2-D matrices.",
121121
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
122-
py::arg("result"), py::arg("depends") = py::list());
122+
py::arg("resultC"), py::arg("depends") = py::list());
123123
}
124124

125125
{
126126
m.def("_gemm_batch", &blas_ext::gemm_batch,
127127
"Call `gemm_batch` from OneMKL BLAS library to return "
128128
"the matrix-matrix product for a batch of 2-D matrices.",
129129
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
130-
py::arg("result"), py::arg("batch_size"), py::arg("stridea"),
131-
py::arg("strideb"), py::arg("stridec"),
132-
py::arg("depends") = py::list());
130+
py::arg("resultC"), py::arg("depends") = py::list());
133131
}
134132
}

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
5959
const std::int64_t,
6060
char *,
6161
const std::int64_t,
62+
bool,
6263
const std::vector<sycl::event> &);
6364

6465
static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
@@ -77,6 +78,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
7778
const std::int64_t ldb,
7879
char *resultC,
7980
const std::int64_t ldc,
81+
bool is_row_major,
8082
const std::vector<sycl::event> &depends)
8183
{
8284
type_utils::validate_type_for_device<Tab>(exec_q);
@@ -91,7 +93,25 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
9193

9294
sycl::event gemm_event;
9395
try {
94-
gemm_event = mkl_blas::row_major::gemm(
96+
auto gemm_func =
97+
[&](sycl::queue &q, oneapi::mkl::transpose transA,
98+
oneapi::mkl::transpose transB, std::int64_t m, std::int64_t n,
99+
std::int64_t k, Tab alpha, const Tab *a, std::int64_t lda,
100+
const Tab *b, std::int64_t ldb, Tab beta, Tc *c,
101+
std::int64_t ldc,
102+
const std::vector<sycl::event> &deps) -> sycl::event {
103+
if (is_row_major) {
104+
return mkl_blas::row_major::gemm(q, transA, transB, m, n, k,
105+
alpha, a, lda, b, ldb, beta, c,
106+
ldc, deps);
107+
}
108+
else {
109+
return mkl_blas::column_major::gemm(q, transA, transB, m, n, k,
110+
alpha, a, lda, b, ldb, beta,
111+
c, ldc, deps);
112+
}
113+
};
114+
gemm_event = gemm_func(
95115
exec_q,
96116
transA, // Defines the transpose operation for matrix A:
97117
// 'N' indicates no transpose, 'T' for transpose,
@@ -130,7 +150,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
130150
return gemm_event;
131151
}
132152

133-
std::pair<sycl::event, sycl::event>
153+
std::tuple<sycl::event, sycl::event, bool>
134154
gemm(sycl::queue &exec_q,
135155
dpctl::tensor::usm_ndarray matrixA,
136156
dpctl::tensor::usm_ndarray matrixB,
@@ -208,16 +228,44 @@ std::pair<sycl::event, sycl::event>
208228
throw py::value_error(
209229
"Result array is not c-contiguous nor f-contiguous.");
210230
}
211-
oneapi::mkl::transpose transA = is_matrixA_f_contig
212-
? oneapi::mkl::transpose::T
213-
: oneapi::mkl::transpose::N;
214-
oneapi::mkl::transpose transB = is_matrixB_f_contig
215-
? oneapi::mkl::transpose::T
216-
: oneapi::mkl::transpose::N;
231+
bool is_row_major = true;
232+
if (is_matrixA_f_contig && is_matrixB_f_contig) {
233+
is_row_major = false;
234+
}
235+
oneapi::mkl::transpose transA;
236+
oneapi::mkl::transpose transB;
237+
if (is_row_major) {
238+
transA = is_matrixA_f_contig ? oneapi::mkl::transpose::T
239+
: oneapi::mkl::transpose::N;
240+
transB = is_matrixB_f_contig ? oneapi::mkl::transpose::T
241+
: oneapi::mkl::transpose::N;
242+
}
243+
else {
244+
transA = oneapi::mkl::transpose::N;
245+
transB = oneapi::mkl::transpose::N;
246+
}
217247

218-
const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m;
219-
const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k;
220-
const std::int64_t ldc = n; // always n for row_major
248+
std::int64_t lda;
249+
std::int64_t ldb;
250+
if (is_row_major) {
251+
if (transA == oneapi::mkl::transpose::N) {
252+
lda = k;
253+
}
254+
else {
255+
lda = m;
256+
}
257+
if (transB == oneapi::mkl::transpose::N) {
258+
ldb = n;
259+
}
260+
else {
261+
ldb = k;
262+
}
263+
}
264+
else {
265+
lda = m;
266+
ldb = k;
267+
}
268+
const std::int64_t ldc = is_row_major ? n : m;
221269

222270
int matrixA_typenum = matrixA.get_typenum();
223271
int matrixB_typenum = matrixB.get_typenum();
@@ -242,14 +290,14 @@ std::pair<sycl::event, sycl::event>
242290
char *b_typeless_ptr = matrixB.get_data();
243291
char *r_typeless_ptr = resultC.get_data();
244292

245-
sycl::event gemm_ev =
246-
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
247-
b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);
293+
sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k,
294+
a_typeless_ptr, lda, b_typeless_ptr, ldb,
295+
r_typeless_ptr, ldc, is_row_major, depends);
248296

249297
sycl::event args_ev = dpctl::utils::keep_args_alive(
250298
exec_q, {matrixA, matrixB, resultC}, {gemm_ev});
251299

252-
return std::make_pair(args_ev, gemm_ev);
300+
return std::make_tuple(args_ev, gemm_ev, is_row_major);
253301
}
254302

255303
template <typename fnT, typename Tab, typename Tc>

dpnp/backend/extensions/blas/gemm.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,22 +38,18 @@ namespace ext
3838
{
3939
namespace blas
4040
{
41-
extern std::pair<sycl::event, sycl::event>
41+
extern std::tuple<sycl::event, sycl::event, bool>
4242
gemm(sycl::queue &exec_q,
4343
dpctl::tensor::usm_ndarray matrixA,
4444
dpctl::tensor::usm_ndarray matrixB,
4545
dpctl::tensor::usm_ndarray resultC,
4646
const std::vector<sycl::event> &depends);
4747

48-
extern std::pair<sycl::event, sycl::event>
48+
extern std::tuple<sycl::event, sycl::event, bool>
4949
gemm_batch(sycl::queue &exec_q,
5050
dpctl::tensor::usm_ndarray matrixA,
5151
dpctl::tensor::usm_ndarray matrixB,
5252
dpctl::tensor::usm_ndarray resultC,
53-
const std::int64_t batch_size,
54-
size_t stridea,
55-
size_t strideb,
56-
size_t stridec,
5753
const std::vector<sycl::event> &depends);
5854

5955
extern void init_gemm_dispatch_table(void);

0 commit comments

Comments
 (0)