Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions .github/workflows/array-api-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,10 @@ array_api_tests/test_linalg.py::test_svd
array_api_tests/test_linalg.py::test_qr
array_api_tests/test_operators_and_elementwise_functions.py::test_clip

# unexpected result is returned
# unexpected result is returned - unmute when dpctl-1986 is resolved
array_api_tests/test_operators_and_elementwise_functions.py::test_asin
array_api_tests/test_operators_and_elementwise_functions.py::test_asinh

# missing 'correction' keyword argument
array_api_tests/test_signatures.py::test_func_signature[std]
array_api_tests/test_signatures.py::test_func_signature[var]

# arrays have different values
array_api_tests/test_linalg.py::test_linalg_tensordot
2 changes: 1 addition & 1 deletion .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ jobs:
id: run_tests_linux
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
timeout_minutes: 12
timeout_minutes: 15
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
retry_on: any
command: |
Expand Down
149 changes: 70 additions & 79 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,12 @@

def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
"""
Determines the output array data type and an intermediate data type
used in performing calculations related to a specific math function.
Determines the output array data type.
If dtype is ``None``, the output array data type of the operation is
determined based on the Promotion Type Rule and device capabilities.
Otherwise, `dtype` is used as output array dtype, if input arrays
can cast to it according to the casting rule determined. If casting
cannot be done, a ``TypeError`` is raised.
The intermediate data type is the data type used for performing the math
function calculations. If output array dtype is a floating-point data type,
it is also used for the intermediate data type. If output array dtype is an
integral data type, the default floating point data type of the device where
input arrays are allocated on are used for intermediate data type.

Parameters
----------
Expand All @@ -78,17 +72,13 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):

Returns
-------
compute_dtype, res_dtype :
`compute_dtype` is the data type used in performing math function calculations.
The input arrays of the math function are cast to `compute_dtype` and then
the calculations are performed.
res_dtype :
`res_dtype` is the output data type. When the result is obtained, it is cast
to `res_dtype`.

"""

res_dtype = dpnp.result_type(*arrays)
default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)

if dtype is not None:
if dpnp.can_cast(res_dtype, dtype, casting=casting):
Expand All @@ -98,11 +88,7 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
f"Cannot cast from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
)

compute_dtype = (
res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
)

return compute_dtype, res_dtype
return res_dtype


def _copy_array(x, copy_flag=False, dtype=None, order="C"):
Expand Down Expand Up @@ -749,17 +735,17 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
_validate_out_array(out, exec_q)

# Determine the appropriate data types
dot_dtype, res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)
res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)

result = _create_result_array(
a, b, out, (), dot_dtype, res_usm_type, exec_q
a, b, out, (), res_dtype, res_usm_type, exec_q
)

# input arrays should have the proper data type
if dpnp.issubdtype(res_dtype, dpnp.inexact):
# copying is needed if dtypes of input arrays are different
a = _copy_array(a, dtype=dot_dtype)
b = _copy_array(b, dtype=dot_dtype)
a = _copy_array(a, dtype=res_dtype)
b = _copy_array(b, dtype=res_dtype)

_manager = dpu.SequentialOrderManager[exec_q]

Expand All @@ -777,14 +763,11 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
)
_manager.add_event_pair(ht_ev, dot_ev)
else:
# oneapi::mkl::blas::dot is slow for integer data type,
# oneapi::mkl::blas::dot does not support integer dtypes,
# so using dpctl.tensor.vecdot instead
dpt_a = dpnp.get_usm_ndarray(a)
dpt_b = dpnp.get_usm_ndarray(b)
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(dpt_a, dpt_b))

if dot_dtype != res_dtype:
result = result.astype(res_dtype, copy=False)
a_usm = dpnp.get_usm_ndarray(a)
b_usm = dpnp.get_usm_ndarray(b)
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(a_usm, b_usm))

return dpnp.get_result_array(result, out, casting=casting)

Expand Down Expand Up @@ -902,7 +885,7 @@ def dpnp_multiplication(
axes_res = normalize_axis_tuple(axes_res, len(result_shape), "axes")

# Determine the appropriate data types
compute_dtype, res_dtype = _compute_res_dtype(
res_dtype = _compute_res_dtype(
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
)

Expand Down Expand Up @@ -998,7 +981,7 @@ def dpnp_multiplication(
x2,
out,
res_shape,
compute_dtype,
res_dtype,
res_usm_type,
exec_q,
res_order,
Expand All @@ -1010,64 +993,72 @@ def dpnp_multiplication(
elif x1.size == 0 or x2.size == 0:
result.fill(0)
else:
# input arrays should have the proper data type and
# their base (last 2-dimensions) to be c-contiguous or f-contiguous
x1 = _copy_array(
x1,
copy_flag=not x1_contig_flag,
dtype=compute_dtype,
order=res_order,
)
x2 = _copy_array(
x2,
copy_flag=not x2_contig_flag,
dtype=compute_dtype,
order=res_order,
)

if call_flag == "gemv":
if transpose:
a_usm = dpnp.get_usm_ndarray(x2)
x_usm = dpnp.get_usm_ndarray(x1)
else:
a_usm = dpnp.get_usm_ndarray(x1)
x_usm = dpnp.get_usm_ndarray(x2)

_manager = dpu.SequentialOrderManager[exec_q]

ht_ev, gemv_ev = bi._gemv(
exec_q,
a_usm,
x_usm,
dpnp.get_usm_ndarray(result),
transpose,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, gemv_ev)
elif call_flag == "gemm":
result = _gemm_matmul(
exec_q,
if dpnp.issubdtype(res_dtype, dpnp.inexact):
# copying is needed if dtypes of input arrays are different or
# their base (last 2-dimensions) is not c-contiguous or f-contiguous
x1 = _copy_array(
x1,
x2,
result,
copy_flag=not x1_contig_flag,
dtype=res_dtype,
order=res_order,
)
else: # call_flag == "gemm_batch"
assert call_flag == "gemm_batch"
result = _gemm_batch_matmul(
exec_q,
x1,
x2 = _copy_array(
x2,
result,
copy_flag=not x2_contig_flag,
dtype=res_dtype,
order=res_order,
)

if call_flag == "gemv":
if transpose:
a_usm = dpnp.get_usm_ndarray(x2)
x_usm = dpnp.get_usm_ndarray(x1)
else:
a_usm = dpnp.get_usm_ndarray(x1)
x_usm = dpnp.get_usm_ndarray(x2)

_manager = dpu.SequentialOrderManager[exec_q]

ht_ev, gemv_ev = bi._gemv(
exec_q,
a_usm,
x_usm,
dpnp.get_usm_ndarray(result),
transpose,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, gemv_ev)
elif call_flag == "gemm":
result = _gemm_matmul(
exec_q,
x1,
x2,
result,
)
else: # call_flag == "gemm_batch"
assert call_flag == "gemm_batch"
result = _gemm_batch_matmul(
exec_q,
x1,
x2,
result,
)
else:
# oneapi::mkl::blas::gemm/gemv do not support integer dtypes,
# so using dpctl.tensor.matmul instead
x1_usm = dpnp.get_usm_ndarray(x1)
x2_usm = dpnp.get_usm_ndarray(x2)
out_usm = dpnp.get_usm_ndarray(result)
res_usm = dpt.matmul(
x1_usm, x2_usm, out=out_usm, dtype=dtype, order=order
)
result = dpnp_array._create_from_usm_ndarray(res_usm)

if NumPy_special_case:
result = dpnp.tile(result, out.shape)
elif res_shape != result_shape:
result = dpnp.reshape(result, result_shape)

if compute_dtype != res_dtype:
result = dpnp.astype(result, res_dtype, copy=False)

if out is None:
if axes is not None:
# Move the data back to the appropriate axes of the result array
Expand Down Expand Up @@ -1207,7 +1198,7 @@ def dpnp_vecdot(
)

# Determine the appropriate data types
_, res_dtype = _compute_res_dtype(
res_dtype = _compute_res_dtype(
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
)

Expand Down
Loading
Loading