Skip to content

Commit 4411a82

Browse files
authored
add kernel for max with axis (#698)
* add kernel for max with axis
1 parent 3baa476 commit 4411a82

File tree

4 files changed

+242
-86
lines changed

4 files changed

+242
-86
lines changed

dpnp/backend/kernels/dpnp_krnl_statistics.cpp

Lines changed: 171 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -146,36 +146,188 @@ class dpnp_max_c_kernel;
146146
template <typename _DataType>
147147
void dpnp_max_c(void* array1_in, void* result1, const size_t* shape, size_t ndim, const size_t* axis, size_t naxis)
148148
{
149-
__attribute__((unused)) void* tmp = (void*)(axis + naxis);
149+
if (naxis == 0)
150+
{
151+
__attribute__((unused)) void* tmp = (void*)(axis + naxis);
150152

151-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
152-
_DataType* result = reinterpret_cast<_DataType*>(result1);
153+
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
154+
_DataType* result = reinterpret_cast<_DataType*>(result1);
153155

154-
size_t size = 1;
155-
for (size_t i = 0; i < ndim; ++i)
156-
{
157-
size *= shape[i];
158-
}
156+
size_t size = 1;
157+
for (size_t i = 0; i < ndim; ++i)
158+
{
159+
size *= shape[i];
160+
}
159161

160-
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
161-
{
162-
// Required initializing the result before call the function
163-
result[0] = array_1[0];
162+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
163+
{
164+
// The result must be initialized before calling the function
165+
result[0] = array_1[0];
164166

165-
auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1, size, array_1);
167+
auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1, size, array_1);
166168

167-
cl::sycl::event event = mkl_stats::max(DPNP_QUEUE, dataset, result);
169+
cl::sycl::event event = mkl_stats::max(DPNP_QUEUE, dataset, result);
168170

169-
event.wait();
171+
event.wait();
172+
}
173+
else
174+
{
175+
auto policy = oneapi::dpl::execution::make_device_policy<class dpnp_max_c_kernel<_DataType>>(DPNP_QUEUE);
176+
177+
_DataType* res = std::max_element(policy, array_1, array_1 + size);
178+
policy.queue().wait();
179+
180+
result[0] = *res;
181+
}
170182
}
171183
else
172184
{
173-
auto policy = oneapi::dpl::execution::make_device_policy<class dpnp_max_c_kernel<_DataType>>(DPNP_QUEUE);
185+
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
186+
_DataType* result = reinterpret_cast<_DataType*>(result1);
187+
188+
size_t res_ndim = ndim - naxis;
189+
size_t res_shape[res_ndim];
190+
int ind = 0;
191+
for (size_t i = 0; i < ndim; i++)
192+
{
193+
bool found = false;
194+
for (size_t j = 0; j < naxis; j++)
195+
{
196+
if (axis[j] == i)
197+
{
198+
found = true;
199+
break;
200+
}
201+
}
202+
if (!found)
203+
{
204+
res_shape[ind] = shape[i];
205+
ind++;
206+
}
207+
}
208+
209+
size_t size_input = 1;
210+
for (size_t i = 0; i < ndim; ++i)
211+
{
212+
size_input *= shape[i];
213+
}
214+
215+
size_t input_shape_offsets[ndim];
216+
size_t acc = 1;
217+
for (size_t i = ndim - 1; i > 0; --i)
218+
{
219+
input_shape_offsets[i] = acc;
220+
acc *= shape[i];
221+
}
222+
input_shape_offsets[0] = acc;
223+
224+
size_t output_shape_offsets[res_ndim];
225+
acc = 1;
226+
if (res_ndim > 0)
227+
{
228+
for (size_t i = res_ndim - 1; i > 0; --i)
229+
{
230+
output_shape_offsets[i] = acc;
231+
acc *= res_shape[i];
232+
}
233+
}
234+
output_shape_offsets[0] = acc;
235+
236+
size_t size_result = 1;
237+
for (size_t i = 0; i < res_ndim; ++i)
238+
{
239+
size_result *= res_shape[i];
240+
}
174241

175-
_DataType* res = std::max_element(policy, array_1, array_1 + size);
176-
policy.queue().wait();
242+
// Initialize each result element with the input value at index 0 along every reduced axis
243+
for (size_t result_idx = 0; result_idx < size_result; ++result_idx)
244+
{
245+
size_t xyz[res_ndim];
246+
size_t remainder = result_idx;
247+
for (size_t i = 0; i < res_ndim; ++i)
248+
{
249+
xyz[i] = remainder / output_shape_offsets[i];
250+
remainder = remainder - xyz[i] * output_shape_offsets[i];
251+
}
252+
253+
size_t source_axis[ndim];
254+
size_t result_axis_idx = 0;
255+
for (size_t idx = 0; idx < ndim; ++idx)
256+
{
257+
bool found = false;
258+
for (size_t i = 0; i < naxis; ++i)
259+
{
260+
if (axis[i] == idx)
261+
{
262+
found = true;
263+
break;
264+
}
265+
}
266+
if (found)
267+
{
268+
source_axis[idx] = 0;
269+
}
270+
else
271+
{
272+
source_axis[idx] = xyz[result_axis_idx];
273+
result_axis_idx++;
274+
}
275+
}
276+
277+
size_t source_idx = 0;
278+
for (size_t i = 0; i < ndim; ++i)
279+
{
280+
source_idx += input_shape_offsets[i] * source_axis[i];
281+
}
282+
283+
result[result_idx] = array_1[source_idx];
284+
}
177285

178-
result[0] = *res;
286+
for (size_t source_idx = 0; source_idx < size_input; ++source_idx)
287+
{
288+
// Reconstruct the multi-dimensional index (one coordinate per input dimension) from the linear source_idx
289+
size_t xyz[ndim];
290+
size_t remainder = source_idx;
291+
for (size_t i = 0; i < ndim; ++i)
292+
{
293+
xyz[i] = remainder / input_shape_offsets[i];
294+
remainder = remainder - xyz[i] * input_shape_offsets[i];
295+
}
296+
297+
// extract result axis
298+
size_t result_axis[res_ndim];
299+
size_t result_idx = 0;
300+
for (size_t idx = 0; idx < ndim; ++idx)
301+
{
302+
// try to find current idx in axis array
303+
bool found = false;
304+
for (size_t i = 0; i < naxis; ++i)
305+
{
306+
if (axis[i] == idx)
307+
{
308+
found = true;
309+
break;
310+
}
311+
}
312+
if (!found)
313+
{
314+
result_axis[result_idx] = xyz[idx];
315+
result_idx++;
316+
}
317+
}
318+
319+
// Construct result offset
320+
size_t result_offset = 0;
321+
for (size_t i = 0; i < res_ndim; ++i)
322+
{
323+
result_offset += output_shape_offsets[i] * result_axis[i];
324+
}
325+
326+
if (result[result_offset] < array_1[source_idx])
327+
{
328+
result[result_offset] = array_1[source_idx];
329+
}
330+
}
179331
}
180332

181333
return;

dpnp/dpnp_algo/dpnp_algo_statistics.pyx

Lines changed: 38 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -144,89 +144,60 @@ cpdef dparray dpnp_cov(dparray array1):
144144
return result
145145

146146

147-
cpdef dparray _dpnp_max(dparray input):
147+
cpdef dparray _dpnp_max(dparray input, _axis_, output_shape):
148148
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
149149

150150
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MAX, param1_type, param1_type)
151151

152-
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
153-
cdef dparray result = dparray((1,), dtype=result_type)
152+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
153+
cdef dparray result = dparray(output_shape, dtype=result_type)
154154

155155
cdef custom_statistic_1in_1out_func_ptr_t func = <custom_statistic_1in_1out_func_ptr_t > kernel_data.ptr
156-
157-
# stub for interface support
158156
cdef dparray_shape_type axis
159157
cdef Py_ssize_t axis_size = 0
158+
cdef dparray_shape_type axis_ = axis
160159

161-
func(input.get_data(), result.get_data(), < size_t * > input._dparray_shape.data(), input.ndim, < size_t * > axis.data(), axis_size)
160+
if _axis_ is not None:
161+
axis = _axis_
162+
axis_.reserve(len(axis))
163+
for shape_it in axis:
164+
axis_.push_back(shape_it)
165+
axis_size = len(axis)
162166

163-
return result
167+
func(input.get_data(), result.get_data(), < size_t * > input._dparray_shape.data(), input.ndim, < size_t * > axis_.data(), axis_size)
164168

169+
dpnp_array = dpnp.array(result, dtype=input.dtype)
170+
dpnp_result_array = dpnp_array.reshape(output_shape)
171+
return dpnp_result_array
165172

166-
cpdef dparray dpnp_max(dparray input, axis):
167-
if axis is None:
168-
return _dpnp_max(input)
169173

174+
cpdef dparray dpnp_max(dparray input, axis):
170175
cdef dparray_shape_type shape_input = input.shape
171-
cdef long size_input = input.size
172-
if isinstance(axis, int):
173-
axis_ = tuple([axis])
174-
else:
176+
if axis is None:
175177
axis_ = axis
176-
177-
output_shape = dparray(len(shape_input) - len(axis_), dtype=numpy.int64)
178-
ind = 0
179-
for id, shape_axis in enumerate(shape_input):
180-
if id not in axis_:
181-
output_shape[ind] = shape_axis
182-
ind += 1
183-
cdef long prod = 1
184-
for i in range(len(output_shape)):
185-
if output_shape[i] != 0:
186-
prod *= output_shape[i]
187-
result_array = [None] * prod
188-
input_shape_offsets = [None] * len(shape_input)
189-
acc = 1
190-
for i in range(len(shape_input)):
191-
ind = len(shape_input) - 1 - i
192-
input_shape_offsets[ind] = acc
193-
acc *= shape_input[ind]
194-
output_shape_offsets = [None] * len(shape_input)
195-
acc = 1
196-
if len(output_shape) > 0:
197-
for i in range(len(output_shape)):
198-
ind = len(output_shape) - 1 - i
199-
output_shape_offsets[ind] = acc
200-
acc *= output_shape[ind]
201-
202-
for source_idx in range(size_input):
203-
204-
# reconstruct x,y,z from linear source_idx
205-
xyz = []
206-
remainder = source_idx
207-
for i in input_shape_offsets:
208-
quotient, remainder = divmod(remainder, i)
209-
xyz.append(quotient)
210-
211-
# extract result axis
212-
result_axis = []
213-
for idx, offset in enumerate(xyz):
214-
if idx not in axis_:
215-
result_axis.append(offset)
216-
217-
# Construct result offset
218-
result_offset = 0
219-
for i, result_axis_val in enumerate(result_axis):
220-
result_offset += (output_shape_offsets[i] * result_axis_val)
221-
222-
input_elem = input.item(source_idx)
223-
if result_array[result_offset] is None:
224-
result_array[result_offset] = input_elem
178+
output_shape = 1
179+
else:
180+
if isinstance(axis, int):
181+
if axis < 0:
182+
axis_ = tuple([input.ndim - axis])
183+
else:
184+
axis_ = tuple([axis])
225185
else:
226-
result_array[result_offset] = max(result_array[result_offset], input_elem)
227-
dpnp_array = dpnp.array(result_array, dtype=input.dtype)
228-
dpnp_result_array = dpnp_array.reshape(output_shape)
229-
return dpnp_result_array
186+
_axis_ = []
187+
for i in range(len(axis)):
188+
if axis[i] < 0:
189+
_axis_.append(input.ndim - axis[i])
190+
else:
191+
_axis_.append(axis[i])
192+
axis_ = tuple(_axis_)
193+
194+
output_shape = dparray(len(shape_input) - len(axis_), dtype=numpy.int64)
195+
ind = 0
196+
for id, shape_axis in enumerate(shape_input):
197+
if id not in axis_:
198+
output_shape[ind] = shape_axis
199+
ind += 1
200+
return _dpnp_max(input, axis_, output_shape)
230201

231202

232203
cpdef dparray _dpnp_mean(dparray input):

dpnp/dpnp_iface_statistics.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,30 @@ def max(input, axis=None, out=None, keepdims=numpy._NoValue, initial=numpy._NoVa
302302
"""
303303

304304
if not use_origin_backend(input):
305+
# Negative values in 'shape' are not allowed in dparray
306+
# Check for negative or duplicate axis values; isaxis is set to False if any is found
307+
isaxis = True
308+
if axis is not None:
309+
if dpnp.isscalar(axis):
310+
if axis < 0:
311+
isaxis = False
312+
else:
313+
for val in axis:
314+
if val < 0:
315+
isaxis = False
316+
break
317+
if isaxis:
318+
for i in range(len(axis)):
319+
for j in range(len(axis)):
320+
if i != j:
321+
if axis[i] == axis[j]:
322+
isaxis = False
323+
break
324+
305325
if not isinstance(input, dparray):
306326
pass
327+
elif not isaxis:
328+
pass
307329
elif out is not None:
308330
pass
309331
elif keepdims is not numpy._NoValue:

tests/test_statistics.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ def test_median(type, size):
2020
numpy.testing.assert_allclose(dpnp_res, np_res)
2121

2222

23+
@pytest.mark.parametrize("axis",
24+
[0, 1, -1, 2, -2, (1, 2), (0, -2)])
25+
def test_max(axis):
26+
a = numpy.arange(768, dtype=numpy.float64).reshape((4, 4, 6, 8))
27+
ia = dpnp.array(a)
28+
29+
np_res = numpy.max(a, axis=axis)
30+
dpnp_res = dpnp.max(ia, axis=axis)
31+
32+
numpy.testing.assert_allclose(dpnp_res, np_res)
33+
2334
@pytest.mark.parametrize("array",
2435
[[2, 0, 6, 2],
2536
[2, 0, 6, 2, 5, 6, 7, 8],

0 commit comments

Comments
 (0)