Skip to content

Commit 8842fa6

Browse files
Add integer datatypes and add types to tests
1 parent c3c194c commit 8842fa6

File tree

5 files changed

+67
-14
lines changed

5 files changed

+67
-14
lines changed

.github/workflows/conda-package.yml

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ jobs:
241241
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-
242242
243243
- name: Install dpnp
244-
run: mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
244+
run: mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest pytest-xdist python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
245245
env:
246246
TEST_CHANNELS: '-c ${{ env.channel-path }} ${{ env.CHANNELS }}'
247247
MAMBA_NO_LOW_SPEED_LIMIT: 1
@@ -257,7 +257,8 @@ jobs:
257257
- name: Run tests
258258
if: env.RERUN_TESTS_ON_FAILURE != 'true'
259259
run: |
260-
python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
260+
export DPNP_TEST_ALL_TYPES=1
261+
python -m pytest -n auto -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
261262
working-directory: ${{ env.tests-path }}
262263

263264
- name: Run tests
@@ -266,14 +267,15 @@ jobs:
266267
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
267268
with:
268269
shell: bash
269-
timeout_minutes: 10
270+
timeout_minutes: 45
270271
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
271272
retry_on: any
272273
command: |
273274
. $CONDA/etc/profile.d/conda.sh
274275
conda activate ${{ env.TEST_ENV_NAME }}
275276
cd ${{ env.tests-path }}
276-
python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
277+
export DPNP_TEST_ALL_TYPES=1
278+
python -m pytest -n auto -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
277279
278280
test_windows:
279281
name: Test ['windows-2019', python='${{ matrix.python }}']
@@ -387,7 +389,7 @@ jobs:
387389
- name: Install dpnp
388390
run: |
389391
@echo on
390-
mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
392+
mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest pytest-xdist python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
391393
env:
392394
TEST_CHANNELS: '-c ${{ env.channel-path }} ${{ env.CHANNELS }}'
393395
MAMBA_NO_LOW_SPEED_LIMIT: 1
@@ -412,7 +414,8 @@ jobs:
412414
- name: Run tests
413415
if: env.RERUN_TESTS_ON_FAILURE != 'true'
414416
run: |
415-
python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
417+
set DPNP_TEST_ALL_TYPES=1
418+
python -m pytest -n auto -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
416419
working-directory: ${{ env.tests-path }}
417420

418421
- name: Run tests
@@ -421,13 +424,14 @@ jobs:
421424
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
422425
with:
423426
shell: cmd
424-
timeout_minutes: 15
427+
timeout_minutes: 45
425428
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
426429
retry_on: any
427430
command: >-
428-
mamba activate ${{ env.TEST_ENV_NAME }}
431+
set DPNP_TEST_ALL_TYPES=1
432+
& mamba activate ${{ env.TEST_ENV_NAME }}
429433
& cd ${{ env.tests-path }}
430-
& python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
434+
& python -m pytest -n auto -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
431435
432436
upload:
433437
name: Upload ['${{ matrix.os }}', python='${{ matrix.python }}']

dpnp/dpnp_iface_types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,14 @@
5959
"inf",
6060
"int",
6161
"int_",
62+
"int8",
63+
"int16",
6264
"int32",
6365
"int64",
66+
"uint8",
67+
"uint16",
68+
"uint32",
69+
"uint64",
6470
"integer",
6571
"intc",
6672
"intp",
@@ -95,8 +101,14 @@
95101
inexact = numpy.inexact
96102
int = numpy.int_
97103
int_ = numpy.int_
104+
int8 = numpy.int8
105+
int16 = numpy.int16
98106
int32 = numpy.int32
99107
int64 = numpy.int64
108+
uint8 = numpy.uint8
109+
uint16 = numpy.uint16
110+
uint32 = numpy.uint32
111+
uint64 = numpy.uint64
100112
integer = numpy.integer
101113
intc = numpy.intc
102114
intp = numpy.intp

tests/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import os
2+
3+
all_types = int(os.getenv("DPNP_TEST_ALL_TYPES", 0))

tests/helper.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from numpy.testing import assert_allclose, assert_array_equal
66

77
import dpnp
8+
from tests import config
89

910

1011
def assert_dtype_allclose(
@@ -88,6 +89,18 @@ def get_integer_dtypes():
8889
Build a list of integer types supported by DPNP.
8990
"""
9091

92+
if config.all_types:
93+
return [
94+
dpnp.int8,
95+
dpnp.int16,
96+
dpnp.int32,
97+
dpnp.int64,
98+
dpnp.uint8,
99+
dpnp.uint16,
100+
dpnp.uint32,
101+
dpnp.uint64,
102+
]
103+
91104
return [dpnp.int32, dpnp.int64]
92105

93106

tests/third_party/cupy/testing/_loops.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from dpctl.tensor._numpy_helper import AxisError
1313

1414
import dpnp as cupy
15+
from tests import config
1516
from tests.third_party.cupy.testing import _array, _parameterized
1617
from tests.third_party.cupy.testing._pytest_impl import is_available
1718

@@ -1039,19 +1040,39 @@ def _get_supported_complex_dtypes():
10391040
return (numpy.complex64,)
10401041

10411042

1043+
def _get_int_dtypes():
1044+
if config.all_types:
1045+
return _signed_dtypes + _unsigned_dtypes
1046+
else:
1047+
return (numpy.int64, numpy.int32)
1048+
1049+
10421050
_complex_dtypes = _get_supported_complex_dtypes()
10431051
_regular_float_dtypes = _get_supported_float_dtypes()
1044-
_float_dtypes = _regular_float_dtypes
1045-
_signed_dtypes = ()
1052+
_float_dtypes = _regular_float_dtypes + (numpy.float16,)
1053+
_signed_dtypes = tuple(numpy.dtype(i).type for i in "bhilq")
10461054
_unsigned_dtypes = tuple(numpy.dtype(i).type for i in "BHILQ")
1047-
_int_dtypes = _signed_dtypes + _unsigned_dtypes
1048-
_int_bool_dtypes = _int_dtypes
1055+
_int_dtypes = _get_int_dtypes()
1056+
_int_bool_dtypes = _int_dtypes + (numpy.bool_,)
10491057
_regular_dtypes = _regular_float_dtypes + _int_bool_dtypes
10501058
_dtypes = _float_dtypes + _int_bool_dtypes
10511059

10521060

10531061
def _make_all_dtypes(no_float16, no_bool, no_complex):
1054-
return (numpy.int64, numpy.int32) + _get_supported_float_dtypes()
1062+
if no_float16:
1063+
dtypes = _regular_float_dtypes
1064+
else:
1065+
dtypes = _float_dtypes
1066+
1067+
if no_bool:
1068+
dtypes += _int_dtypes
1069+
else:
1070+
dtypes += _int_bool_dtypes
1071+
1072+
if not no_complex:
1073+
dtypes += _complex_dtypes
1074+
1075+
return dtypes
10551076

10561077

10571078
def for_all_dtypes(

0 commit comments

Comments
 (0)