Skip to content

Commit 6cc6dee

Browse files
author
Vahid Tavanashad
committed
address new comments
1 parent 161c617 commit 6cc6dee

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, out=None, casting="no"):
8989
# are cast to out dtype and then calculation is performed. Even when inputs
9090
# are boolean and `dtype` is given, the casting is done first and then the
9191
# calculation is performed.
92-
if out is not None and not dpnp.issubdtype(res_dtype, dpnp.bool):
92+
if out is not None and res_dtype != dpnp.bool:
9393
# out dtype is prioritized over a given dtype
9494
dtype = out.dtype
9595

@@ -509,15 +509,10 @@ def _gemm_special_case(x1, x2, res_dtype, call_flag):
509509
while `gemv` does not.
510510
511511
"""
512-
513512
# TODO: replace with dpnp.int8 when it is added
514-
x1_is_int8 = dpnp.issubdtype(x1.dtype, numpy.int8)
515-
x2_is_int8 = dpnp.issubdtype(x2.dtype, numpy.int8)
516-
res_is_int32 = dpnp.issubdtype(res_dtype, dpnp.int32)
517-
res_is_float32 = dpnp.issubdtype(res_dtype, dpnp.float32)
518-
519-
flag = x1_is_int8 and x2_is_int8 and (res_is_int32 or res_is_float32)
520-
flag = flag and call_flag in ["gemm", "gemm_batch"]
513+
is_int8 = x1.dtype == numpy.int8 and x2.dtype == numpy.int8
514+
is_int32_or_f32 = res_dtype in [dpnp.int32, dpnp.float32]
515+
flag = is_int8 and is_int32_or_f32 and call_flag in ["gemm", "gemm_batch"]
521516

522517
# onemkl_interfaces does not support these data types
523518
onemkl_interfaces = bi._using_onemkl_interfaces()
@@ -1084,7 +1079,8 @@ def dpnp_multiplication(
10841079
result = _gemm_batch_matmul(exec_q, x1, x2, result)
10851080
else:
10861081
# oneapi::mkl::blas::gemm/gemv do not support integer dtypes,
1087-
# so using dpctl.tensor.matmul instead
1082+
# except for special cases determined in `_gemm_special_case`,
1083+
# use dpctl.tensor.matmul for unsupported cases
10881084

10891085
# `dpt.matmul` does not support `casting` kwarg.
10901086
# We may need to change input dtypes based on given `casting`.
@@ -1096,10 +1092,9 @@ def dpnp_multiplication(
10961092
x1_usm = dpnp.get_usm_ndarray(x1)
10971093
x2_usm = dpnp.get_usm_ndarray(x2)
10981094
out_usm = dpnp.get_usm_ndarray(result)
1099-
res_usm = dpt.matmul(
1095+
dpt.matmul(
11001096
x1_usm, x2_usm, out=out_usm, dtype=dtype, order=order
11011097
)
1102-
result = dpnp_array._create_from_usm_ndarray(res_usm)
11031098

11041099
if NumPy_special_case:
11051100
result = dpnp.tile(result, out.shape)

dpnp/tests/test_product.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -888,7 +888,7 @@ def test_order(self, dtype, order1, order2, order, shape1, shape2):
888888
def test_strided1(self, dtype, stride):
889889
for dim in [1, 2, 3, 4]:
890890
shape = tuple(20 for _ in range(dim))
891-
A = numpy.random.rand(*shape).astype(dtype)
891+
A = generate_random_numpy_array(shape, dtype)
892892
iA = dpnp.array(A)
893893
slices = tuple(slice(None, None, stride[i]) for i in range(dim))
894894
a = A[slices]
@@ -897,13 +897,13 @@ def test_strided1(self, dtype, stride):
897897
# the 2D base is neither c-contiguous nor f-contiguous
898898
result = dpnp.matmul(ia, ia)
899899
expected = numpy.matmul(a, a)
900-
assert_dtype_allclose(result, expected)
900+
assert_dtype_allclose(result, expected, factor=16)
901901

902902
iOUT = dpnp.empty(shape, dtype=result.dtype)
903903
iout = iOUT[slices]
904904
result = dpnp.matmul(ia, ia, out=iout)
905905
assert result is iout
906-
assert_dtype_allclose(result, expected)
906+
assert_dtype_allclose(result, expected, factor=16)
907907

908908
@pytest.mark.parametrize("dtype", _selected_dtypes)
909909
@pytest.mark.parametrize(
@@ -915,7 +915,7 @@ def test_strided2(self, dtype, shape, stride, transpose):
915915
# one dimension (axis=-3) is strided
916916
# if negative stride, copy is needed and the base becomes c-contiguous
917917
# otherwise the base remains the same as input in gemm_batch
918-
A = numpy.random.rand(*shape).astype(dtype)
918+
A = generate_random_numpy_array(shape, dtype)
919919
iA = dpnp.array(A)
920920
if transpose:
921921
A = numpy.moveaxis(A, (-2, -1), (-1, -2))
@@ -948,7 +948,7 @@ def test_strided3(self, dtype, stride, transpose):
948948
# For positive stride, no copy but reshape makes the base c-contiguous.
949949
stride0, stride1 = stride
950950
shape = (12, 10, 3, 3) # 4D array
951-
A = numpy.random.rand(*shape).astype(dtype)
951+
A = generate_random_numpy_array(shape, dtype)
952952
iA = dpnp.array(A)
953953
if transpose:
954954
A = numpy.moveaxis(A, (-2, -1), (-1, -2))
@@ -980,7 +980,7 @@ def test_strided_mat_vec(self, dtype, func, incx, incy, transpose):
980980
else:
981981
s1 = shape[-1]
982982
s2 = shape[-2]
983-
a = numpy.random.rand(*shape).astype(dtype)
983+
a = generate_random_numpy_array(shape, dtype)
984984
ia = dpnp.array(a)
985985
if transpose:
986986
a = numpy.moveaxis(a, (-2, -1), (-1, -2))
@@ -1016,7 +1016,7 @@ def test_strided_vec_mat(self, dtype, func, incx, incy, transpose):
10161016
else:
10171017
s1 = shape[-1]
10181018
s2 = shape[-2]
1019-
a = numpy.random.rand(*shape).astype(dtype)
1019+
a = generate_random_numpy_array(shape, dtype)
10201020
ia = dpnp.array(a)
10211021
if transpose:
10221022
a = numpy.moveaxis(a, (-2, -1), (-1, -2))

0 commit comments

Comments
 (0)