Skip to content

Commit 5898507

Browse files
author
Vahid Tavanashad
committed
use syrk for boolean dtypes when possible
1 parent 03d08c5 commit 5898507

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,10 @@ def dpnp_multiplication(
986986
if _is_syrk_compatible(x1, x2):
987987
call_flag = "syrk"
988988
res_dtype_orig = res_dtype
989-
if dpnp.issubdtype(res_dtype, dpnp.integer):
989+
# for exact dtypes, use syrk implementation unlike general approach
990+
# where dpctl implementation is used for exact dtypes for better
991+
# performance
992+
if not dpnp.issubdtype(res_dtype, dpnp.inexact):
990993
res_dtype = dpnp.default_float_type(x1.device)
991994
elif x1_base_is_1D:
992995
# TODO: implement gemv_batch to use it here with transpose

dpnp/tests/test_product.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,7 @@ def test_special_case(self, dt_out, shape1, shape2):
11711171
result = dpnp.matmul(ia, ib, out=iout)
11721172
assert_dtype_allclose(result, expected)
11731173

1174-
@pytest.mark.parametrize("dt", get_all_dtypes())
1174+
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True))
11751175
def test_syrk(self, dt):
11761176
a = generate_random_numpy_array((6, 9), dtype=dt, low=-5, high=5)
11771177
ia = dpnp.array(a)

0 commit comments

Comments
 (0)