Skip to content

Commit 0a03265

Browse files
authored
Support testing of USMArray (#894)
* Add patching of 2 numpy asserts * Add assert_equal to assert usm_ndarray * Add converting tuple(usm_ndarray) to tuple(ndarray) for testing
1 parent a390ca7 commit 0a03265

File tree

4 files changed

+65
-0
lines changed

4 files changed

+65
-0
lines changed

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ __all__ = [
5555
"checker_throw_type_error",
5656
"checker_throw_value_error",
5757
"create_output_descriptor_py",
58+
"convert_item",
5859
"dpnp_descriptor",
5960
"get_axis_indeces",
6061
"get_axis_offsets",

tests/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import dpnp
2+
import numpy
3+
4+
from tests import testing
5+
6+
7+
numpy.testing.assert_allclose = testing.assert_allclose
8+
numpy.testing.assert_array_equal = testing.assert_array_equal
9+
numpy.testing.assert_equal = testing.assert_equal

tests/testing/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from tests.testing.array import assert_allclose
2+
from tests.testing.array import assert_array_equal
3+
from tests.testing.array import assert_equal

tests/testing/array.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# -*- coding: utf-8 -*-
2+
# *****************************************************************************
3+
# Copyright (c) 2016-2020, Intel Corporation
4+
# All rights reserved.
5+
#
6+
# Redistribution and use in source and binary forms, with or without
7+
# modification, are permitted provided that the following conditions are met:
8+
# - Redistributions of source code must retain the above copyright notice,
9+
# this list of conditions and the following disclaimer.
10+
# - Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
18+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
24+
# THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
import numpy
28+
from dpnp.dpnp_utils import convert_item
29+
30+
31+
assert_allclose_orig = numpy.testing.assert_allclose
32+
assert_array_equal_orig = numpy.testing.assert_array_equal
33+
assert_equal_orig = numpy.testing.assert_equal
34+
35+
36+
def _assert(assert_func, result, expected, *args, **kwargs):
37+
result = convert_item(result)
38+
expected = convert_item(expected)
39+
40+
assert_func(result, expected, *args, **kwargs)
41+
42+
43+
def assert_allclose(result, expected, *args, **kwargs):
44+
_assert(assert_allclose_orig, result, expected, *args, **kwargs)
45+
46+
47+
def assert_array_equal(result, expected, *args, **kwargs):
48+
_assert(assert_array_equal_orig, result, expected, *args, **kwargs)
49+
50+
51+
def assert_equal(result, expected, *args, **kwargs):
52+
_assert(assert_equal_orig, result, expected, *args, **kwargs)

0 commit comments

Comments
 (0)