|
64 | 64 | "indices", |
65 | 65 | "ix_", |
66 | 66 | "mask_indices", |
| 67 | + "ndindex", |
67 | 68 | "nonzero", |
68 | 69 | "place", |
69 | 70 | "put", |
@@ -1057,6 +1058,77 @@ def mask_indices( |
1057 | 1058 | return nonzero(a != 0) |
1058 | 1059 |
|
1059 | 1060 |
|
| 1061 | +# pylint: disable=invalid-name |
| 1062 | +# pylint: disable=too-few-public-methods |
| 1063 | +class ndindex: |
| 1064 | + """ |
| 1065 | + An N-dimensional iterator object to index arrays. |
| 1066 | +
|
| 1067 | + Given the shape of an array, an :obj:`dpnp.ndindex` instance iterates over |
| 1068 | + the N-dimensional index of the array. At each iteration a tuple of indices |
| 1069 | + is returned, the last dimension is iterated over first. |
| 1070 | +
|
| 1071 | + For full documentation refer to :obj:`numpy.ndindex`. |
| 1072 | +
|
| 1073 | + Parameters |
| 1074 | + ---------- |
| 1075 | + shape : ints, or a single tuple of ints |
| 1076 | + The size of each dimension of the array can be passed as individual |
| 1077 | + parameters or as the elements of a tuple. |
| 1078 | +
|
| 1079 | + See Also |
| 1080 | + -------- |
| 1081 | + :obj:`dpnp.ndenumerate` : Multidimensional index iterator. |
| 1082 | + :obj:`dpnp.flatiter` : Flat iterator object to iterate over arrays. |
| 1083 | +
|
| 1084 | + Examples |
| 1085 | + -------- |
| 1086 | + >>> import dpnp as np |
| 1087 | +
|
| 1088 | + Dimensions as individual arguments |
| 1089 | +
|
| 1090 | + >>> for index in np.ndindex(3, 2, 1): |
| 1091 | + ... print(index) |
| 1092 | + (0, 0, 0) |
| 1093 | + (0, 1, 0) |
| 1094 | + (1, 0, 0) |
| 1095 | + (1, 1, 0) |
| 1096 | + (2, 0, 0) |
| 1097 | + (2, 1, 0) |
| 1098 | +
|
| 1099 | + Same dimensions - but in a tuple ``(3, 2, 1)`` |
| 1100 | +
|
| 1101 | + >>> for index in np.ndindex((3, 2, 1)): |
| 1102 | + ... print(index) |
| 1103 | + (0, 0, 0) |
| 1104 | + (0, 1, 0) |
| 1105 | + (1, 0, 0) |
| 1106 | + (1, 1, 0) |
| 1107 | + (2, 0, 0) |
| 1108 | + (2, 1, 0) |
| 1109 | +
|
| 1110 | + """ |
| 1111 | + |
| 1112 | + def __init__(self, *shape): |
| 1113 | + self.ndindex_ = numpy.ndindex(*shape) |
| 1114 | + |
| 1115 | + def __iter__(self): |
| 1116 | + return self.ndindex_ |
| 1117 | + |
| 1118 | + def __next__(self): |
| 1119 | + """ |
| 1120 | + Standard iterator method, updates the index and returns the index tuple. |
| 1121 | +
|
| 1122 | + Returns |
| 1123 | + ------- |
| 1124 | + val : tuple of ints |
| 1125 | + Returns a tuple containing the indices of the current iteration. |
| 1126 | +
|
| 1127 | + """ |
| 1128 | + |
| 1129 | + return self.ndindex_.__next__() |
| 1130 | + |
| 1131 | + |
1060 | 1132 | def nonzero(a): |
1061 | 1133 | """ |
1062 | 1134 | Return the indices of the elements that are non-zero. |
|
0 commit comments