Skip to content

Commit 161c617

Browse files
author
Vahid Tavanashad
committed
updates for onemkl interfaces
1 parent 41420f2 commit 161c617

File tree

3 files changed

+50
-25
lines changed

3 files changed

+50
-25
lines changed

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,14 @@ PYBIND11_MODULE(_blas_impl, m)
146146

147147
{
148148
m.def(
149-
"_row_major_is_available",
150-
[](void) {
151-
#if defined(USE_ONEMKL_CUBLAS)
152-
return false;
153-
#else
149+
"_using_onemkl_interfaces",
150+
[]() {
151+
#ifdef USE_ONEMKL_INTERFACES
154152
return true;
155-
#endif // USE_ONEMKL_CUBLAS
153+
#else
154+
return false;
155+
#endif
156156
},
157-
"Check if the onemkl::blas::row_major can be used.");
157+
"Check if the OneMKL interfaces are being used.");
158158
}
159159
}

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,28 @@ def _gemm_matmul(exec_q, x1, x2, res):
503503
return res
504504

505505

506+
def _gemm_special_case(x1, x2, res_dtype, call_flag):
507+
"""
508+
`gemm` and `gemm_batch` support these special cases of data types
509+
while `gemv` does not.
510+
511+
"""
512+
513+
# TODO: replace with dpnp.int8 when it is added
514+
x1_is_int8 = dpnp.issubdtype(x1.dtype, numpy.int8)
515+
x2_is_int8 = dpnp.issubdtype(x2.dtype, numpy.int8)
516+
res_is_int32 = dpnp.issubdtype(res_dtype, dpnp.int32)
517+
res_is_float32 = dpnp.issubdtype(res_dtype, dpnp.float32)
518+
519+
flag = x1_is_int8 and x2_is_int8 and (res_is_int32 or res_is_float32)
520+
flag = flag and call_flag in ["gemm", "gemm_batch"]
521+
522+
# onemkl_interfaces does not support these data types
523+
onemkl_interfaces = bi._using_onemkl_interfaces()
524+
525+
return flag and not onemkl_interfaces
526+
527+
506528
def _shape_error(shape1, shape2, func, err_msg):
507529
"""Validate the shapes of input and output arrays."""
508530

@@ -1008,21 +1030,18 @@ def dpnp_multiplication(
10081030
elif x1.size == 0 or x2.size == 0:
10091031
result.fill(0)
10101032
else:
1011-
# TODO: replace with dpnp.int8 when it is added
1012-
x1_is_int8 = dpnp.issubdtype(x1.dtype, numpy.int8)
1013-
x2_is_int8 = dpnp.issubdtype(x2.dtype, numpy.int8)
1014-
res_is_int32 = dpnp.issubdtype(res_dtype, dpnp.int32)
1015-
special_case = x1_is_int8 and x2_is_int8 and res_is_int32
1016-
special_case = special_case and call_flag == "gemm"
1017-
if special_case:
1018-
# OneMath supports this special case
1033+
if _gemm_special_case(x1, x2, res_dtype, call_flag):
10191034
x1 = _copy_array(
10201035
x1, copy_flag=not x1_contig_flag, order=res_order
10211036
)
10221037
x2 = _copy_array(
10231038
x2, copy_flag=not x2_contig_flag, order=res_order
10241039
)
1025-
result = _gemm_matmul(exec_q, x1, x2, result)
1040+
if call_flag == "gemm":
1041+
result = _gemm_matmul(exec_q, x1, x2, result)
1042+
else:
1043+
assert call_flag == "gemm_batch"
1044+
result = _gemm_batch_matmul(exec_q, x1, x2, result)
10261045
elif dpnp.issubdtype(res_dtype, dpnp.inexact):
10271046
# copying is needed if dtypes of input arrays are different or
10281047
# their base (last 2-dimensions) is not c-contiguous or f-contiguous

dpnp/tests/test_product.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ def test_dtype_matrix(self, dt_in1, dt_in2, dt_out, shape1, shape2):
834834
# while NumPy give slightly different results. NumPy result obtained
835835
# with `dtype` is much closer to dpnp (a smaller `tol`) while the
836836
# result obtained with `out` needs a larger `tol` to match dpnp
837-
assert_allclose(result, expected, rtol=1e-6, atol=1e-6)
837+
assert_allclose(result, expected, rtol=1e-5, atol=1e-5)
838838
else:
839839
assert_raises(TypeError, dpnp.matmul, ia, ib, dtype=dt_out)
840840
assert_raises(TypeError, numpy.matmul, a, b, dtype=dt_out)
@@ -1151,23 +1151,29 @@ def test_large_values(self, dtype):
11511151
expected = numpy.matmul(a, b)
11521152
assert_dtype_allclose(result, expected)
11531153

1154-
def test_special_case(self):
1154+
@pytest.mark.parametrize("dt_out", [numpy.int32, numpy.float32])
1155+
@pytest.mark.parametrize(
1156+
"shape1, shape2",
1157+
[((2, 4), (4, 3)), ((4, 2, 3), (4, 3, 5))],
1158+
ids=["gemm", "gemm_batch"],
1159+
)
1160+
def test_special_case(self, dt_out, shape1, shape2):
11551161
# Although inputs are int, gemm will be used for calculation
1156-
a = numpy.ones((3, 4), dtype=numpy.int8)
1157-
b = numpy.ones((4, 5), dtype=numpy.int8)
1162+
a = numpy.ones(shape1, dtype=numpy.int8)
1163+
b = numpy.ones(shape2, dtype=numpy.int8)
11581164
ia, ib = dpnp.array(a), dpnp.array(b)
11591165

1160-
result = dpnp.matmul(ia, ib, dtype=numpy.int32)
1161-
expected = numpy.matmul(a, b, dtype=numpy.int32)
1166+
result = dpnp.matmul(ia, ib, dtype=dt_out)
1167+
expected = numpy.matmul(a, b, dtype=dt_out)
11621168
assert_dtype_allclose(result, expected)
11631169

1164-
iout = dpnp.empty((3, 5), dtype=numpy.int32)
1170+
iout = dpnp.empty(result.shape, dtype=dt_out)
11651171
result = dpnp.matmul(ia, ib, out=iout)
11661172
assert_dtype_allclose(result, expected)
11671173

11681174
def test_bool(self):
1169-
a = numpy.ones((3, 4), dtype=numpy.bool)
1170-
b = numpy.ones((4, 5), dtype=numpy.bool)
1175+
a = numpy.ones((3, 4), dtype=dpnp.bool)
1176+
b = numpy.ones((4, 5), dtype=dpnp.bool)
11711177
ia, ib = dpnp.array(a), dpnp.array(b)
11721178

11731179
# the output is (3, 4) array filled with 4

0 commit comments

Comments
 (0)