Skip to content

Commit db26d6b

Browse files
pre-commit
1 parent 9119ae9 commit db26d6b

File tree

4 files changed

+56
-48
lines changed

4 files changed

+56
-48
lines changed

dpnp/backend/extensions/statistics/kth_element1d.cpp

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
#include "kth_element1d.hpp"
4141
#include "partitioning.hpp"
4242

43-
#include <iostream>
4443
#include <chrono>
44+
#include <iostream>
4545

4646
namespace sycl_exp = sycl::ext::oneapi::experimental;
4747
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
@@ -152,7 +152,8 @@ struct KthElementF
152152
auto gh = sycl_exp::group_with_scratchpad(
153153
group, sycl::span{&scratch[0], temp_memory_size});
154154
if (num_elems > 0)
155-
sycl_exp::joint_sort(gh, &_in[0], &_in[num_elems], Less<T>{});
155+
sycl_exp::joint_sort(gh, &_in[0], &_in[num_elems],
156+
Less<T>{});
156157

157158
if (group.leader()) {
158159
uint64_t offset = state.counters.less_count[0];
@@ -188,7 +189,8 @@ struct KthElementF
188189
auto gh = sycl_exp::group_with_scratchpad(
189190
group, sycl::span{&scratch[0], temp_memory_size});
190191
sycl_exp::joint_sort(gh, &loc_items[0],
191-
&loc_items[0] + items_to_sort, Less<T>{});
192+
&loc_items[0] + items_to_sort,
193+
Less<T>{});
192194

193195
T new_pivot = loc_items[items_to_sort / 2];
194196

@@ -256,7 +258,8 @@ struct KthElementF
256258
uint32_t limit = 4 * (items_to_sort + 1);
257259
uint32_t iterations =
258260
std::ceil(-std::log(double(state.n) / limit) / std::log(0.536)) + 1;
259-
// Ensure iterations are odd so the final result is always stored in 'partitioned'
261+
// Ensure iterations are odd so the final result is always stored in
262+
// 'partitioned'
260263
iterations += 1 - iterations % 2;
261264

262265
auto prev = run_pick_pivot(exec_q, const_cast<T *>(in), partitioned, k,
@@ -267,8 +270,8 @@ struct KthElementF
267270
T *_in = partitioned;
268271
T *_out = temp_buff;
269272
for (uint32_t i = 0; i < iterations - 1; ++i) {
270-
prev = run_pick_pivot(exec_q, _in, _out, k, state,
271-
items_to_sort, limit, {prev});
273+
prev = run_pick_pivot(exec_q, _in, _out, k, state, items_to_sort,
274+
limit, {prev});
272275
prev = run_partition(exec_q, _in, _out, pstate, {prev});
273276
std::swap(_in, _out);
274277
}
@@ -278,13 +281,12 @@ struct KthElementF
278281
return prev;
279282
}
280283

281-
static KthElement1d::RetT
282-
impl(sycl::queue &exec_queue,
283-
const void *v_ain,
284-
void *v_partitioned,
285-
const size_t a_size,
286-
const size_t k,
287-
const std::vector<sycl::event> &depends)
284+
static KthElement1d::RetT impl(sycl::queue &exec_queue,
285+
const void *v_ain,
286+
void *v_partitioned,
287+
const size_t a_size,
288+
const size_t k,
289+
const std::vector<sycl::event> &depends)
288290
{
289291
auto start = std::chrono::high_resolution_clock::now();
290292
const T *ain = static_cast<const T *>(v_ain);
@@ -298,9 +300,9 @@ struct KthElementF
298300
init_e = pstate.init(exec_queue, {init_e});
299301

300302
auto temp_buff = dpctl_utils::smart_malloc<T>(state.n, exec_queue,
301-
sycl::usm::alloc::device);
302-
auto evt = run_kth_element(exec_queue, ain, partitioned, temp_buff.get(), k, state,
303-
pstate, {init_e});
303+
sycl::usm::alloc::device);
304+
auto evt = run_kth_element(exec_queue, ain, partitioned,
305+
temp_buff.get(), k, state, pstate, {init_e});
304306

305307
bool found = false;
306308
bool left = false;
@@ -315,7 +317,8 @@ struct KthElementF
315317
copy_evt = exec_queue.copy(state.counters.greater_equal_count,
316318
&greater_equal_count, 1, copy_evt);
317319
copy_evt = exec_queue.copy(state.num_elems, &num_elems, 1, copy_evt);
318-
copy_evt = exec_queue.copy(state.counters.nan_count, &nan_count, 1, copy_evt);
320+
copy_evt =
321+
exec_queue.copy(state.counters.nan_count, &nan_count, 1, copy_evt);
319322

320323
uint64_t buff_offset = 0;
321324
uint64_t elems_offset = less_count;
@@ -344,25 +347,30 @@ struct KthElementF
344347
.count();
345348

346349
std::cout << "KthElement1d took " << duration << " microseconds"
347-
<< std::endl;
350+
<< std::endl;
348351
return {found, buff_offset, elems_offset, num_elems, nan_count};
349352
}
350353
};
351354

352-
using SupportedTypes =
353-
std::tuple<uint32_t, int32_t, uint64_t, int64_t, float, double, std::complex<float>, std::complex<double>>;
355+
using SupportedTypes = std::tuple<uint32_t,
356+
int32_t,
357+
uint64_t,
358+
int64_t,
359+
float,
360+
double,
361+
std::complex<float>,
362+
std::complex<double>>;
354363
} // namespace
355364

356365
KthElement1d::KthElement1d() : dispatch_table("a")
357366
{
358367
dispatch_table.populate_dispatch_table<SupportedTypes, KthElementF>();
359368
}
360369

361-
KthElement1d::RetT
362-
KthElement1d::call(const dpctl::tensor::usm_ndarray &a,
363-
dpctl::tensor::usm_ndarray &partitioned,
364-
const size_t k,
365-
const std::vector<sycl::event> &depends)
370+
KthElement1d::RetT KthElement1d::call(const dpctl::tensor::usm_ndarray &a,
371+
dpctl::tensor::usm_ndarray &partitioned,
372+
const size_t k,
373+
const std::vector<sycl::event> &depends)
366374
{
367375
// validate(a, partitioned, k);
368376

dpnp/backend/extensions/statistics/kth_element1d.hpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,21 @@ namespace statistics::partitioning
3434
struct KthElement1d
3535
{
3636
using RetT = std::tuple<bool, uint64_t, uint64_t, uint64_t, uint64_t>;
37-
using FnT = RetT (*)(
38-
sycl::queue &,
39-
const void *,
40-
void *,
41-
const size_t,
42-
const size_t,
43-
const std::vector<sycl::event> &);
37+
using FnT = RetT (*)(sycl::queue &,
38+
const void *,
39+
void *,
40+
const size_t,
41+
const size_t,
42+
const std::vector<sycl::event> &);
4443

4544
ext::common::DispatchTable<FnT> dispatch_table;
4645

4746
KthElement1d();
4847

49-
RetT
50-
call(const dpctl::tensor::usm_ndarray &a,
51-
dpctl::tensor::usm_ndarray &partitioned,
52-
uint64_t k,
53-
const std::vector<sycl::event> &depends);
48+
RetT call(const dpctl::tensor::usm_ndarray &a,
49+
dpctl::tensor::usm_ndarray &partitioned,
50+
uint64_t k,
51+
const std::vector<sycl::event> &depends);
5452
};
5553

5654
void populate_kth_element1d(py::module_ m);

dpnp/backend/extensions/statistics/partitioning.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ void submit_partition_one_pivot(sycl::handler &cgh,
347347
out[sbg_less_offset + le_item_offset + le_pos] =
348348
values[_i];
349349
}
350-
else if (!is_nan){
350+
else if (!is_nan) {
351351
out[sbg_gr_offset + gr_item_offset + ge_pos] =
352352
values[_i];
353353
}

dpnp/dpnp_utils/dpnp_utils_statistics.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,12 @@
2727

2828
import dpctl
2929
import dpctl.tensor as dpt
30+
import dpctl.utils as dpu
3031
from dpctl.tensor._numpy_helper import normalize_axis_tuple
3132
from dpctl.utils import ExecutionPlacementError
32-
import dpnp.backend.extensions.statistics._statistics_impl as statistics_ext
33-
34-
import dpctl.utils as dpu
3533

3634
import dpnp
35+
import dpnp.backend.extensions.statistics._statistics_impl as statistics_ext
3736
from dpnp.dpnp_array import dpnp_array
3837

3938
__all__ = ["dpnp_cov", "dpnp_median"]
@@ -193,6 +192,7 @@ def dpnp_cov(
193192
c = dpnp.dot(x, x_t.conj()) / fact
194193
return c.squeeze()
195194

195+
196196
def native_median(a):
197197

198198
partitioned = dpnp.empty_like(a)
@@ -201,14 +201,16 @@ def native_median(a):
201201

202202
_manager = dpu.SequentialOrderManager[a.sycl_queue]
203203

204-
result = dpnp.empty_like(a, shape = 1)
204+
result = dpnp.empty_like(a, shape=1)
205205
k = a.shape[0] // 2
206206

207-
found, buff_offset, elems_offset, num_elems, nan_count = statistics_ext.kth_element(
208-
a_usm,
209-
partitioned_usm,
210-
k,
211-
depends=_manager.submitted_events,
207+
found, buff_offset, elems_offset, num_elems, nan_count = (
208+
statistics_ext.kth_element(
209+
a_usm,
210+
partitioned_usm,
211+
k,
212+
depends=_manager.submitted_events,
213+
)
212214
)
213215

214216
if found:
@@ -218,7 +220,7 @@ def native_median(a):
218220
else:
219221
result[0] = partitioned[0]
220222
else:
221-
partitioned[buff_offset:buff_offset + num_elems].sort()
223+
partitioned[buff_offset : buff_offset + num_elems].sort()
222224
kth_idx = buff_offset + k - elems_offset
223225
if a.shape[0] % 2 == 0:
224226
# even number of elements

0 commit comments

Comments
 (0)