Skip to content

Commit a8a93da

Browse files
Call implementation from python
1 parent 7ed9c7e commit a8a93da

File tree

4 files changed

+153
-50
lines changed

4 files changed

+153
-50
lines changed

dpnp/backend/extensions/statistics/kth_element1d.cpp

Lines changed: 108 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
#include "kth_element1d.hpp"
4141
#include "partitioning.hpp"
4242

43-
// #include <iostream>
43+
#include <iostream>
44+
#include <chrono>
4445

4546
namespace sycl_exp = sycl::ext::oneapi::experimental;
4647
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
@@ -67,6 +68,7 @@ struct KthElementF
6768
State<T> &state,
6869
uint64_t items_to_sort,
6970
uint64_t limit,
71+
bool ret,
7072
const std::vector<sycl::event> &deps)
7173
{
7274
auto e = queue.submit([&](sycl::handler &cgh) {
@@ -84,6 +86,12 @@ struct KthElementF
8486
auto scratch = sycl::local_accessor<std::byte, 1>(
8587
sycl::range<1>(temp_memory_size), cgh);
8688

89+
// std::cout << "temp_memory_size: " << temp_memory_size
90+
// << " items_to_sort: " << items_to_sort
91+
// << " limit: " << limit
92+
// << " group_size: " << group_size << "\n";
93+
94+
// auto str = sycl::stream(8192, 1024, cgh);
8795
cgh.parallel_for<pick_pivot_kernel<T>>(
8896
work_sz, [=](sycl::nd_item<1> item) {
8997
auto group = item.get_group();
@@ -129,7 +137,6 @@ struct KthElementF
129137
target_found = true;
130138
}
131139
}
132-
133140
state.reset_iteration_counters();
134141
}
135142

@@ -142,10 +149,15 @@ struct KthElementF
142149
return;
143150
}
144151

152+
// if (group.leader()) {
153+
// str << "num_elems: " << num_elems << "\n";
154+
// }
155+
145156
if (num_elems <= limit) {
146157
auto gh = sycl_exp::group_with_scratchpad(
147158
group, sycl::span{&scratch[0], temp_memory_size});
148-
sycl_exp::joint_sort(gh, &_in[0], &_in[num_elems]);
159+
if (num_elems > 0)
160+
sycl_exp::joint_sort(gh, &_in[0], &_in[num_elems]);
149161

150162
if (group.leader()) {
151163
uint64_t offset = state.counters.less_count[0];
@@ -154,9 +166,18 @@ struct KthElementF
154166
state.counters.less_count[0] - num_elems;
155167
}
156168

157-
uint64_t idx = target - offset;
158-
state.values[0] = _in[idx];
159-
state.values[1] = _in[idx + 1];
169+
int64_t idx = target - offset;
170+
171+
// if (idx + 1 > (in + state.n - _in) || idx < 0)
172+
// {
173+
// str << "buffer access out of bounds idx = "
174+
// << idx << " size " << (in + state.n - _in) << "\n";
175+
// }
176+
// else
177+
{
178+
state.values[0] = _in[idx];
179+
state.values[1] = _in[idx + 1];
180+
}
160181

161182
state.stop[0] = true;
162183
state.target_found[0] = true;
@@ -165,6 +186,9 @@ struct KthElementF
165186
return;
166187
}
167188

189+
// if (ret)
190+
// return;
191+
168192
uint64_t step = num_elems / items_to_sort;
169193
for (uint32_t i = llid; i < items_to_sort; i += local_size)
170194
{
@@ -184,30 +208,30 @@ struct KthElementF
184208

185209
T new_pivot = loc_items[items_to_sort / 2];
186210

187-
if (new_pivot != state.pivot[0]) {
211+
// if (new_pivot != state.pivot[0]) {
188212
if (group.leader()) {
189213
state.pivot[0] = new_pivot;
190214
state.num_elems[0] = num_elems;
191215
}
192216
return;
193-
}
194-
195-
auto start = llid + items_to_sort / 2 + 1;
196-
uint32_t index = start;
197-
for (uint32_t i = start; i < items_to_sort; i += local_size)
198-
{
199-
if (loc_items[i] != new_pivot) {
200-
index = i;
201-
break;
202-
}
203-
}
204-
205-
index = sycl::reduce_over_group(group, index,
206-
sycl::minimum<>());
207-
if (group.leader()) {
208-
state.pivot[0] = loc_items[index];
209-
state.num_elems[0] = num_elems;
210-
}
217+
// }
218+
219+
// auto start = llid + items_to_sort / 2 + 1;
220+
// uint32_t index = start;
221+
// for (uint32_t i = start; i < items_to_sort; i += local_size)
222+
// {
223+
// if (loc_items[i] != new_pivot) {
224+
// index = i;
225+
// break;
226+
// }
227+
// }
228+
229+
// index = sycl::reduce_over_group(group, index,
230+
// sycl::minimum<>());
231+
// if (group.leader()) {
232+
// state.pivot[0] = loc_items[index];
233+
// state.num_elems[0] = num_elems;
234+
// }
211235
});
212236
});
213237

@@ -225,7 +249,7 @@ struct KthElementF
225249
auto e = exec_q.submit([&](sycl::handler &cgh) {
226250
cgh.depends_on(deps);
227251

228-
constexpr uint32_t WorkPI = 4; // empirically found number
252+
constexpr uint32_t WorkPI = 1; // empirically found number
229253

230254
auto work_range = make_ndrange(state.n, group_size, WorkPI);
231255
submit_partition_one_pivot<T, WorkPI>(cgh, work_range, in, out,
@@ -243,47 +267,57 @@ struct KthElementF
243267
PartitionState<T> &pstate,
244268
const std::vector<sycl::event> &depends)
245269
{
246-
uint32_t items_to_sort = 128;
247-
uint32_t limit = 4 * items_to_sort;
270+
uint32_t items_to_sort = 127;
271+
uint32_t limit = 4 * (items_to_sort + 1);
248272
uint32_t iterations =
249-
std::ceil(std::log(double(state.n) / limit) / std::log(2));
273+
std::ceil(-std::log(double(state.n) / limit) / std::log(0.536)) + 1;
274+
// Ensure iterations are odd so the final result is always stored in 'partitioned'
275+
iterations += 1 - iterations % 2;
250276

251277
auto temp_buff = dpctl_utils::smart_malloc<T>(state.n, exec_q,
252278
sycl::usm::alloc::device);
253279

280+
std::cout << "Iteration " << 0 << std::endl;
254281
auto prev = run_pick_pivot(exec_q, const_cast<T *>(in), partitioned, k,
255-
state, items_to_sort, limit, depends);
282+
state, items_to_sort, limit, false, depends);
256283
prev = run_partition(exec_q, const_cast<T *>(in), partitioned, pstate,
257284
{prev});
285+
// prev.wait();
258286

259287
T *_in = partitioned;
260288
T *_out = temp_buff.get();
261289
for (uint32_t i = 0; i < iterations - 1; ++i) {
262-
prev = run_pick_pivot(exec_q, _in, _out, k, state, limit,
263-
items_to_sort, {prev});
290+
std::cout << "Iteration " << i + 1 << std::endl;
291+
prev = run_pick_pivot(exec_q, _in, _out, k, state,
292+
items_to_sort, limit, true, {prev});
264293
prev = run_partition(exec_q, _in, _out, pstate, {prev});
265294
std::swap(_in, _out);
295+
// prev.wait();
296+
// if (i % 5 == 0)
297+
// prev.wait();
266298
}
267-
prev = run_pick_pivot(exec_q, _in, _out, k, state, limit, items_to_sort,
268-
{prev});
299+
prev = run_pick_pivot(exec_q, _in, _out, k, state, items_to_sort, limit,
300+
true, {prev});
269301

270302
return prev;
271303
}
272304

273-
static std::tuple<bool, uint64_t, uint64_t, uint64_t>
305+
static KthElement1d::RetT
274306
impl(sycl::queue &exec_queue,
275307
const void *v_ain,
276308
void *v_partitioned,
277309
const size_t a_size,
278310
const size_t k,
279311
const std::vector<sycl::event> &depends)
280312
{
313+
auto start = std::chrono::high_resolution_clock::now();
281314
const T *ain = static_cast<const T *>(v_ain);
282315
T *partitioned = static_cast<T *>(v_partitioned);
283316

284317
State<T> state(exec_queue, a_size, partitioned);
285318
PartitionState<T> pstate(state);
286319

320+
exec_queue.wait();
287321
auto init_e = state.init(exec_queue, depends);
288322
init_e = pstate.init(exec_queue, {init_e});
289323

@@ -295,34 +329,60 @@ struct KthElementF
295329
uint64_t less_count = 0;
296330
uint64_t greater_equal_count = 0;
297331
uint64_t num_elems = 0;
332+
uint64_t nan_count = 0;
298333
auto copy_evt = exec_queue.copy(state.target_found, &found, 1, evt);
299334
copy_evt = exec_queue.copy(state.left, &left, 1, copy_evt);
300335
copy_evt = exec_queue.copy(state.counters.less_count, &less_count, 1,
301336
copy_evt);
302337
copy_evt = exec_queue.copy(state.counters.greater_equal_count,
303338
&greater_equal_count, 1, copy_evt);
304339
copy_evt = exec_queue.copy(state.num_elems, &num_elems, 1, copy_evt);
305-
306-
copy_evt.wait();
340+
copy_evt = exec_queue.copy(state.counters.nan_count, &nan_count, 1, copy_evt);
307341

308342
uint64_t buff_offset = 0;
309343
uint64_t elems_offset = less_count;
310-
if (!found) {
311-
if (left) {
312-
elems_offset = less_count - num_elems;
344+
345+
try
346+
{
347+
copy_evt.wait();
348+
349+
if (!found) {
350+
if (left) {
351+
elems_offset = less_count - num_elems;
352+
}
353+
else {
354+
buff_offset = a_size - num_elems;
355+
}
313356
}
314357
else {
315-
buff_offset = a_size - num_elems;
358+
num_elems = 2;
359+
elems_offset = k;
316360
}
361+
362+
state.cleanup(exec_queue);
363+
auto end = std::chrono::high_resolution_clock::now();
364+
365+
auto duration =
366+
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
367+
.count();
368+
369+
std::cout << "KthElement1d took " << duration << " microseconds"
370+
<< std::endl;
371+
372+
std::cout << "Found " << found << " left " << left
373+
<< " less_count " << less_count
374+
<< " greater_equal_count " << greater_equal_count
375+
<< " num_elems " << num_elems
376+
<< " nan_count " << nan_count
377+
<< std::endl;
378+
/* code */
317379
}
318-
else {
319-
num_elems = 2;
320-
elems_offset = k;
380+
catch (sycl::exception const &e)
381+
{
382+
std::cout << e.what() << std::endl;
321383
}
322384

323-
state.cleanup(exec_queue);
324-
325-
return {found, buff_offset, elems_offset, num_elems};
385+
return {found, buff_offset, elems_offset, num_elems, nan_count};
326386
}
327387
};
328388

@@ -335,7 +395,7 @@ KthElement1d::KthElement1d() : dispatch_table("a")
335395
dispatch_table.populate_dispatch_table<SupportedTypes, KthElementF>();
336396
}
337397

338-
std::tuple<bool, uint64_t, uint64_t, uint64_t>
398+
KthElement1d::RetT
339399
KthElement1d::call(const dpctl::tensor::usm_ndarray &a,
340400
dpctl::tensor::usm_ndarray &partitioned,
341401
const size_t k,

dpnp/backend/extensions/statistics/kth_element1d.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ namespace statistics::partitioning
3333
{
3434
struct KthElement1d
3535
{
36-
using FnT = std::tuple<bool, uint64_t, uint64_t, uint64_t> (*)(
36+
using RetT = std::tuple<bool, uint64_t, uint64_t, uint64_t, uint64_t>;
37+
using FnT = RetT (*)(
3738
sycl::queue &,
3839
const void *,
3940
void *,
@@ -45,7 +46,7 @@ struct KthElement1d
4546

4647
KthElement1d();
4748

48-
std::tuple<bool, uint64_t, uint64_t, uint64_t>
49+
RetT
4950
call(const dpctl::tensor::usm_ndarray &a,
5051
dpctl::tensor::usm_ndarray &partitioned,
5152
uint64_t k,

dpnp/backend/extensions/statistics/partitioning.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ void submit_partition_one_pivot(sycl::handler &cgh,
205205
{
206206
auto loc_counters =
207207
sycl::local_accessor<uint32_t, 1>(sycl::range<1>(4), cgh);
208+
// sycl::stream str(8192, 1024, cgh);
208209
cgh.parallel_for<partition_one_pivot_kernel<T>>(
209210
work_sz, [=](sycl::nd_item<1> item) {
210211
if (state.stop[0])

dpnp/dpnp_utils/dpnp_utils_statistics.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
import dpctl.tensor as dpt
3030
from dpctl.tensor._numpy_helper import normalize_axis_tuple
3131
from dpctl.utils import ExecutionPlacementError
32+
import dpnp.backend.extensions.statistics._statistics_impl as statistics_ext
33+
34+
import dpctl.utils as dpu
3235

3336
import dpnp
3437
from dpnp.dpnp_array import dpnp_array
@@ -190,6 +193,41 @@ def dpnp_cov(
190193
c = dpnp.dot(x, x_t.conj()) / fact
191194
return c.squeeze()
192195

196+
def native_median(a):
197+
198+
partitioned = dpnp.empty_like(a)
199+
a_usm = dpnp.get_usm_ndarray(a)
200+
partitioned_usm = dpnp.get_usm_ndarray(partitioned)
201+
202+
_manager = dpu.SequentialOrderManager[a.sycl_queue]
203+
204+
result = dpnp.empty_like(a, shape = 1)
205+
k = a.shape[0] // 2
206+
207+
found, buff_offset, elems_offset, num_elems, nan_count = statistics_ext.kth_element(
208+
a_usm,
209+
partitioned_usm,
210+
k,
211+
depends=_manager.submitted_events,
212+
)
213+
214+
if found:
215+
if a.shape[0] % 2 == 0:
216+
# even number of elements
217+
result[0] = (partitioned[0] + partitioned[1]) / 2
218+
else:
219+
result[0] = partitioned[0]
220+
else:
221+
partitioned[buff_offset:buff_offset + num_elems].sort()
222+
kth_idx = buff_offset + k - elems_offset
223+
if a.shape[0] % 2 == 0:
224+
# even number of elements
225+
result[0] = (partitioned[kth_idx] + partitioned[kth_idx + 1]) / 2
226+
else:
227+
result[0] = partitioned[kth_idx]
228+
229+
return result
230+
193231

194232
def dpnp_median(
195233
a,
@@ -223,6 +261,9 @@ def dpnp_median(
223261
)
224262
axis = -1
225263

264+
if not isinstance(a.dtype, dpnp.complexfloating) and not ignore_nan and a_ndim == 1:
265+
return native_median(a)
266+
226267
if overwrite_input:
227268
if isinstance(a, dpt.usm_ndarray):
228269
# dpnp.ndarray.sort only works with dpnp_array

0 commit comments

Comments
 (0)