Skip to content

Commit f0de1c6

Browse files
Copy inputs if they have strides (#941)
* Copy inputs if they have strides Co-authored-by: Oleksandr Pavlyk <[email protected]>
1 parent 10ccd07 commit f0de1c6

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

dpnp/dpnp_iface.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def convert_single_elem_array_to_scalar(obj, keepdims=False):
150150
return obj
151151

152152

153-
def get_dpnp_descriptor(ext_obj):
153+
def get_dpnp_descriptor(ext_obj, copy_when_strides=True):
154154
"""
155155
Return True:
156156
never
@@ -169,6 +169,14 @@ def get_dpnp_descriptor(ext_obj):
169169
if use_origin_backend():
170170
return False
171171

172+
# while dpnp functions have no implementation with strides support
173+
# we need to create a non-strided copy
174+
# if function get implementation for strides case
175+
# then this behavior can be disabled with setting "copy_when_strides"
176+
if copy_when_strides and getattr(ext_obj, "strides", None) is not None:
177+
# TODO: replace this workaround when usm_ndarray will provide such functionality
178+
ext_obj = array(ext_obj)
179+
172180
dpnp_desc = dpnp_descriptor(ext_obj)
173181
if dpnp_desc.is_valid:
174182
return dpnp_desc

tests/test_strides.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
3+
import dpnp
4+
import numpy
5+
6+
def _getattr(ex, str_):
7+
attrs = str_.split(".")
8+
res = ex
9+
for attr in attrs:
10+
res = getattr(res, attr)
11+
return res
12+
13+
@pytest.mark.parametrize("func_name",
14+
['abs',])
15+
@pytest.mark.parametrize("type",
16+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
17+
ids=['float64', 'float32', 'int64', 'int32'])
18+
def test_strides(func_name, type):
19+
shape = (4, 4)
20+
a = numpy.arange(shape[0] * shape[1], dtype=type).reshape(shape)
21+
a_strides = a[0::2, 0::2]
22+
dpa = dpnp.array(a)
23+
dpa_strides = dpa[0::2, 0::2]
24+
25+
dpnp_func = _getattr(dpnp, func_name)
26+
result = dpnp_func(dpa_strides)
27+
28+
numpy_func = _getattr(numpy, func_name)
29+
expected = numpy_func(a_strides)
30+
31+
numpy.testing.assert_allclose(expected, result)

0 commit comments

Comments
 (0)