Skip to content

Commit 4b2145c

Browse files
Merge pull request #1008 from IntelPython/flatiter-impl
Flatiter impl
2 parents c5fee4b + 8a02ac0 commit 4b2145c

File tree

4 files changed

+102
-4
lines changed

4 files changed

+102
-4
lines changed

dpnp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242

4343
# from dpnp.dparray import dparray as ndarray
4444
from dpnp.dpnp_array import dpnp_array as ndarray
45+
from dpnp.dpnp_flatiter import flatiter as flatiter
46+
4547
from dpnp.dpnp_iface import *
4648
from dpnp.dpnp_iface import __all__ as _iface__all__
4749
from dpnp.dpnp_iface_types import *

dpnp/dpnp_array.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,16 @@ def dtype(self):
397397
# 'dumps',
398398
# 'fill',
399399
# 'flags',
400-
# 'flat',
400+
401+
@property
402+
def flat(self):
403+
"""
404+
Return a flat iterator, or set a flattened version of self to value.
405+
406+
"""
407+
408+
return dpnp.flatiter(self)
409+
401410
# 'flatten',
402411
# 'getfield',
403412
# 'imag',

dpnp/dpnp_flatiter.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# -*- coding: utf-8 -*-
2+
# *****************************************************************************
3+
# Copyright (c) 2016-2020, Intel Corporation
4+
# All rights reserved.
5+
#
6+
# Redistribution and use in source and binary forms, with or without
7+
# modification, are permitted provided that the following conditions are met:
8+
# - Redistributions of source code must retain the above copyright notice,
9+
# this list of conditions and the following disclaimer.
10+
# - Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
18+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
24+
# THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
"""
28+
Implementation of flatiter
29+
30+
"""
31+
32+
import dpnp
33+
34+
class flatiter:
35+
def __init__(self, X):
36+
if type(X) is not dpnp.ndarray:
37+
raise TypeError(
38+
"Argument must be of type dpnp.ndarray, got {}".format(
39+
type(X)
40+
)
41+
)
42+
self.arr_ = X
43+
self.size_ = X.size
44+
self.i_ = 0
45+
46+
def _multiindex(self, i):
47+
nd = self.arr_.ndim
48+
if nd == 0:
49+
if i == 0:
50+
return tuple()
51+
raise KeyError
52+
elif nd == 1:
53+
return (i,)
54+
sh = self.arr_.shape
55+
i_ = i
56+
multi_index = [0] * nd
57+
for k in reversed(range(1, nd)):
58+
si = sh[k]
59+
q = i_ // si
60+
multi_index[k] = i_ - q * si
61+
i_ = q
62+
multi_index[0] = i_
63+
return tuple(multi_index)
64+
65+
def __getitem__(self, key):
66+
idx = getattr(key, "__index__", None)
67+
if not callable(idx):
68+
raise TypeError(key)
69+
i = idx()
70+
mi = self._multiindex(i)
71+
return self.arr_.__getitem__(mi)
72+
73+
def __setitem__(self, key, val):
74+
idx = getattr(key, "__index__", None)
75+
if not callable(idx):
76+
raise TypeError(key)
77+
i = idx()
78+
mi = self._multiindex(i)
79+
return self.arr_.__setitem__(mi, val)
80+
81+
def __iter__(self):
82+
return self
83+
84+
def __next__(self):
85+
if self.i_ < self.size_:
86+
val = self.__getitem__(self.i_)
87+
self.i_ = self.i_ + 1
88+
return val
89+
else:
90+
raise StopIteration

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,6 @@ tests/test_fft.py::test_fft[float32]
365365
tests/test_fft.py::test_fft[float64]
366366
tests/test_fft.py::test_fft[int32]
367367
tests/test_fft.py::test_fft[int64]
368-
tests/test_flat.py::test_flat[int64]
369-
tests/test_flat.py::test_flat2[int64]
370-
tests/test_flat.py::test_flat3[int64]
371368
tests/test_linalg.py::test_cond[-1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
372369
tests/test_linalg.py::test_cond[1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
373370
tests/test_linalg.py::test_cond[-2-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]

0 commit comments

Comments
 (0)