Skip to content

Commit 8fdfc1b

Browse files
author
Diptorup Deb
authored
Merge pull request #1283 from IntelPython/feature/improve_parfor_ufunc_testing
Improve built-in ops and dpnp universal functions testing
2 parents 1608e43 + 6a341d2 commit 8fdfc1b

File tree

8 files changed

+531
-443
lines changed

8 files changed

+531
-443
lines changed

numba_dpex/core/typing/dpnpdecl.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,16 @@ def install_operations(cls):
179179
"ldexp",
180180
"spacing",
181181
"isnat",
182+
"cbrt",
182183
]
183184
)
184185

185-
# A list of ufuncs that are in fact aliases of other ufuncs. They need to insert
186-
# the resolve method, but not register the ufunc itself
187-
_aliases = set(["bitwise_not", "mod", "abs"])
186+
# TODO: A list of ufuncs that are in fact aliases of other ufuncs. They need
187+
# to insert the resolve method, but not register the ufunc itself.
188+
# In a meantime let's just register them as user functions:
189+
# TODO: for some reason it affects "mod", but does not affect "bitwise_not" and
190+
# "abs". May be mod is not an alias?
191+
_aliases = {"bitwise_not", "abs"}
188192

189193
all_ufuncs = sum(
190194
[

numba_dpex/tests/_helper.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
# SPDX-License-Identifier: Apache-2.0
66

77
import contextlib
8+
import inspect
89
import shutil
910
from functools import cache
1011

1112
import dpctl
1213
import dpnp
1314
import pytest
1415

15-
from numba_dpex import dpjit, numba_sem_version
16-
from numba_dpex.core import config
16+
from numba_dpex import config, dpjit, numba_sem_version
1717

1818

1919
@cache
@@ -179,6 +179,14 @@ def get_complex_dtypes(device=None):
179179
return dtypes
180180

181181

182+
def get_int_dtypes(device=None):
183+
"""
184+
Build a list of integer types supported by DPNP based on device capabilities.
185+
"""
186+
187+
return [dpnp.int32, dpnp.int64]
188+
189+
182190
def get_float_dtypes(no_float16=True, device=None):
183191
"""
184192
Build a list of floating types supported by DPNP based on device capabilities.
@@ -227,7 +235,7 @@ def get_all_dtypes(
227235

228236
# add integer types
229237
if not no_int:
230-
dtypes.extend([dpnp.int32, dpnp.int64])
238+
dtypes.extend(get_int_dtypes(device=dev))
231239

232240
# add floating types
233241
if not no_float:
@@ -276,3 +284,20 @@ def skip_if_dtype_not_supported(dt, q_or_dev):
276284
pytest.skip(
277285
f"{dev.name} does not support half precision floating point type"
278286
)
287+
288+
289+
def num_required_arguments(func):
290+
"""Returns number of required arguments of the functions. Does not work
291+
with kwargs arguments."""
292+
if func == dpnp.true_divide:
293+
func = dpnp.divide
294+
295+
sig = inspect.signature(func)
296+
params = sig.parameters
297+
required_args = [
298+
p
299+
for p in params
300+
if params[p].default == inspect._empty and p != "kwargs"
301+
]
302+
303+
return len(required_args)

0 commit comments

Comments
 (0)