Skip to content

Commit 9119ae9

Browse files
Crash due to buffer deletion and add support for complex type
1 parent a518e0c commit 9119ae9

File tree

3 files changed

+62
-96
lines changed

3 files changed

+62
-96
lines changed

dpnp/backend/extensions/statistics/kth_element1d.cpp

Lines changed: 50 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ struct KthElementF
6868
State<T> &state,
6969
uint64_t items_to_sort,
7070
uint64_t limit,
71-
bool ret,
7271
const std::vector<sycl::event> &deps)
7372
{
7473
auto e = queue.submit([&](sycl::handler &cgh) {
@@ -149,15 +148,11 @@ struct KthElementF
149148
return;
150149
}
151150

152-
// if (group.leader()) {
153-
// str << "num_elems: " << num_elems << "\n";
154-
// }
155-
156151
if (num_elems <= limit) {
157152
auto gh = sycl_exp::group_with_scratchpad(
158153
group, sycl::span{&scratch[0], temp_memory_size});
159154
if (num_elems > 0)
160-
sycl_exp::joint_sort(gh, &_in[0], &_in[num_elems]);
155+
sycl_exp::joint_sort(gh, &_in[0], &_in[num_elems], Less<T>{});
161156

162157
if (group.leader()) {
163158
uint64_t offset = state.counters.less_count[0];
@@ -168,16 +163,8 @@ struct KthElementF
168163

169164
int64_t idx = target - offset;
170165

171-
// if (idx + 1 > (in + state.n - _in) || idx < 0)
172-
// {
173-
// str << "buffer access out of bounds idx = "
174-
// << idx << " size " << (in + state.n - _in) << "\n";
175-
// }
176-
// else
177-
{
178-
state.values[0] = _in[idx];
179-
state.values[1] = _in[idx + 1];
180-
}
166+
state.values[0] = _in[idx];
167+
state.values[1] = _in[idx + 1];
181168

182169
state.stop[0] = true;
183170
state.target_found[0] = true;
@@ -186,9 +173,6 @@ struct KthElementF
186173
return;
187174
}
188175

189-
// if (ret)
190-
// return;
191-
192176
uint64_t step = num_elems / items_to_sort;
193177
for (uint32_t i = llid; i < items_to_sort; i += local_size)
194178
{
@@ -204,34 +188,34 @@ struct KthElementF
204188
auto gh = sycl_exp::group_with_scratchpad(
205189
group, sycl::span{&scratch[0], temp_memory_size});
206190
sycl_exp::joint_sort(gh, &loc_items[0],
207-
&loc_items[0] + items_to_sort);
191+
&loc_items[0] + items_to_sort, Less<T>{});
208192

209193
T new_pivot = loc_items[items_to_sort / 2];
210194

211-
// if (new_pivot != state.pivot[0]) {
195+
if (new_pivot != state.pivot[0]) {
212196
if (group.leader()) {
213197
state.pivot[0] = new_pivot;
214198
state.num_elems[0] = num_elems;
215199
}
216200
return;
217-
// }
218-
219-
// auto start = llid + items_to_sort / 2 + 1;
220-
// uint32_t index = start;
221-
// for (uint32_t i = start; i < items_to_sort; i += local_size)
222-
// {
223-
// if (loc_items[i] != new_pivot) {
224-
// index = i;
225-
// break;
226-
// }
227-
// }
228-
229-
// index = sycl::reduce_over_group(group, index,
230-
// sycl::minimum<>());
231-
// if (group.leader()) {
232-
// state.pivot[0] = loc_items[index];
233-
// state.num_elems[0] = num_elems;
234-
// }
201+
}
202+
203+
auto start = llid + items_to_sort / 2 + 1;
204+
uint32_t index = start;
205+
for (uint32_t i = start; i < items_to_sort; i += local_size)
206+
{
207+
if (loc_items[i] != new_pivot) {
208+
index = i;
209+
break;
210+
}
211+
}
212+
213+
index = sycl::reduce_over_group(group, index,
214+
sycl::minimum<>());
215+
if (group.leader()) {
216+
state.pivot[0] = loc_items[index];
217+
state.num_elems[0] = num_elems;
218+
}
235219
});
236220
});
237221

@@ -262,6 +246,7 @@ struct KthElementF
262246
static sycl::event run_kth_element(sycl::queue &exec_q,
263247
const T *in,
264248
T *partitioned,
249+
T *temp_buff,
265250
const size_t k,
266251
State<T> &state,
267252
PartitionState<T> &pstate,
@@ -274,31 +259,21 @@ struct KthElementF
274259
// Ensure iterations are odd so the final result is always stored in 'partitioned'
275260
iterations += 1 - iterations % 2;
276261

277-
auto temp_buff = dpctl_utils::smart_malloc<T>(state.n, exec_q,
278-
sycl::usm::alloc::device);
279-
280-
std::cout << "Iteration " << 0 << std::endl;
281262
auto prev = run_pick_pivot(exec_q, const_cast<T *>(in), partitioned, k,
282-
state, items_to_sort, limit, false, depends);
263+
state, items_to_sort, limit, depends);
283264
prev = run_partition(exec_q, const_cast<T *>(in), partitioned, pstate,
284265
{prev});
285-
// prev.wait();
286266

287267
T *_in = partitioned;
288-
T *_out = temp_buff.get();
268+
T *_out = temp_buff;
289269
for (uint32_t i = 0; i < iterations - 1; ++i) {
290-
std::cout << "Iteration " << i + 1 << std::endl;
291270
prev = run_pick_pivot(exec_q, _in, _out, k, state,
292-
items_to_sort, limit, true, {prev});
271+
items_to_sort, limit, {prev});
293272
prev = run_partition(exec_q, _in, _out, pstate, {prev});
294273
std::swap(_in, _out);
295-
// prev.wait();
296-
// if (i % 5 == 0)
297-
// prev.wait();
298274
}
299-
prev.wait();
300275
prev = run_pick_pivot(exec_q, _in, _out, k, state, items_to_sort, limit,
301-
true, {prev});
276+
{prev});
302277

303278
return prev;
304279
}
@@ -322,7 +297,9 @@ struct KthElementF
322297
auto init_e = state.init(exec_queue, depends);
323298
init_e = pstate.init(exec_queue, {init_e});
324299

325-
auto evt = run_kth_element(exec_queue, ain, partitioned, k, state,
300+
auto temp_buff = dpctl_utils::smart_malloc<T>(state.n, exec_queue,
301+
sycl::usm::alloc::device);
302+
auto evt = run_kth_element(exec_queue, ain, partitioned, temp_buff.get(), k, state,
326303
pstate, {init_e});
327304

328305
bool found = false;
@@ -343,52 +320,37 @@ struct KthElementF
343320
uint64_t buff_offset = 0;
344321
uint64_t elems_offset = less_count;
345322

346-
try
347-
{
348-
copy_evt.wait();
349-
350-
if (!found) {
351-
if (left) {
352-
elems_offset = less_count - num_elems;
353-
}
354-
else {
355-
buff_offset = a_size - num_elems;
356-
}
323+
copy_evt.wait();
324+
325+
if (!found) {
326+
if (left) {
327+
elems_offset = less_count - num_elems;
357328
}
358329
else {
359-
num_elems = 2;
360-
elems_offset = k;
330+
buff_offset = a_size - num_elems;
361331
}
362-
363-
auto end = std::chrono::high_resolution_clock::now();
364-
365-
auto duration =
366-
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
367-
.count();
368-
369-
std::cout << "KthElement1d took " << duration << " microseconds"
370-
<< std::endl;
371-
372-
std::cout << "Found " << found << " left " << left
373-
<< " less_count " << less_count
374-
<< " greater_equal_count " << greater_equal_count
375-
<< " num_elems " << num_elems
376-
<< " nan_count " << nan_count
377-
<< std::endl;
378-
/* code */
379332
}
380-
catch (sycl::exception const &e)
381-
{
382-
std::cout << e.what() << std::endl;
333+
else {
334+
num_elems = 2;
335+
elems_offset = k;
383336
}
384337

385338
state.cleanup(exec_queue);
339+
340+
auto end = std::chrono::high_resolution_clock::now();
341+
342+
auto duration =
343+
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
344+
.count();
345+
346+
std::cout << "KthElement1d took " << duration << " microseconds"
347+
<< std::endl;
386348
return {found, buff_offset, elems_offset, num_elems, nan_count};
387349
}
388350
};
389351

390352
using SupportedTypes =
391-
std::tuple<uint32_t, int32_t, uint64_t, int64_t, float, double>;
353+
std::tuple<uint32_t, int32_t, uint64_t, int64_t, float, double, std::complex<float>, std::complex<double>>;
392354
} // namespace
393355

394356
KthElement1d::KthElement1d() : dispatch_table("a")

dpnp/backend/extensions/statistics/partitioning.hpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,15 @@ void submit_partition_one_pivot(sycl::handler &cgh,
256256
auto i = local_i_base + _i * sbg_size;
257257
if (i < num_elems) {
258258
values[_i] = _in[i];
259-
less_count += Less<T>{}(values[_i], value);
260-
equal_count += values[_i] == value;
261-
nan_count += IsNan<T>::isnan(values[_i]);
259+
auto is_nan = IsNan<T>::isnan(values[_i]);
260+
less_count += (Less<T>{}(values[_i], value) && !is_nan);
261+
equal_count += (values[_i] == value && !is_nan);
262+
nan_count += is_nan;
262263
actual_count++;
263264
}
264265
}
265266

266-
greater_equal_count = actual_count - less_count;
267+
greater_equal_count = actual_count - less_count - nan_count;
267268

268269
auto sbg_less_equal =
269270
sycl::reduce_over_group(sbg, less_count, sycl::plus<>());
@@ -329,21 +330,24 @@ void submit_partition_one_pivot(sycl::handler &cgh,
329330
uint32_t gr_item_offset = 0;
330331

331332
for (uint32_t _i = 0; _i < WorkPI; ++_i) {
332-
uint32_t less = values[_i] < value;
333+
uint32_t is_nan = IsNan<T>::isnan(values[_i]);
334+
uint32_t less = (!is_nan && Less<T>{}(values[_i], value));
333335
auto le_pos =
334336
sycl::exclusive_scan_over_group(sbg, less, sycl::plus<>());
335337
auto ge_pos = sbg.get_local_linear_id() - le_pos;
336338

337339
auto total_le =
338340
sycl::reduce_over_group(sbg, less, sycl::plus<>());
339-
auto total_gr = sbg_size - total_le;
341+
auto total_nan =
342+
sycl::reduce_over_group(sbg, is_nan, sycl::plus<>());
343+
auto total_gr = sbg_size - total_le - total_nan;
340344

341345
if (_i < actual_count) {
342346
if (less) {
343347
out[sbg_less_offset + le_item_offset + le_pos] =
344348
values[_i];
345349
}
346-
else {
350+
else if (!is_nan){
347351
out[sbg_gr_offset + gr_item_offset + ge_pos] =
348352
values[_i];
349353
}

dpnp/dpnp_utils/dpnp_utils_statistics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def dpnp_median(
261261
)
262262
axis = -1
263263

264-
if not isinstance(a.dtype, dpnp.complexfloating) and not ignore_nan and a_ndim == 1:
264+
if not ignore_nan and a_ndim == 1:
265265
return native_median(a)
266266

267267
if overwrite_input:

0 commit comments

Comments
 (0)