Skip to content

Commit bfc7ede

Browse files
Review comments fixes
1 parent 1f3f31b commit bfc7ede

File tree

6 files changed

+39
-24
lines changed

6 files changed

+39
-24
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ repos:
5252
rev: 24.4.2
5353
hooks:
5454
- id: black
55-
args: ["--check", "--diff", "--color"]
55+
args: ["--diff", "--color"]
5656
- repo: https://github.com/pycqa/isort
5757
rev: 5.13.2
5858
hooks:

dpnp/backend/extensions/sycl_ext/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,20 @@ set(_module_src
3434
pybind11_add_module(${python_module_name} MODULE ${_module_src})
3535
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src})
3636

37+
if(_dpnp_sycl_targets)
38+
# make fat binary
39+
target_compile_options(
40+
${python_module_name}
41+
PRIVATE
42+
-fsycl-targets=${_dpnp_sycl_targets}
43+
)
44+
target_link_options(
45+
${python_module_name}
46+
PRIVATE
47+
-fsycl-targets=${_dpnp_sycl_targets}
48+
)
49+
endif()
50+
3751
if (WIN32)
3852
if (${CMAKE_VERSION} VERSION_LESS "3.27")
3953
# this is a work-around for target_link_options inserting option after -link option, cause

dpnp/backend/extensions/sycl_ext/histogram.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ namespace
5050
{
5151

5252
template <typename T, typename BinsT, typename HistType = size_t>
53-
static sycl::event histogram_impl(sycl::queue exec_q,
53+
static sycl::event histogram_impl(sycl::queue &exec_q,
5454
const void *vin,
5555
const void *vbins_edges,
5656
const void *vweights,
@@ -71,15 +71,15 @@ static sycl::event histogram_impl(sycl::queue exec_q,
7171
? 256
7272
: device.get_info<sycl::info::device::max_work_group_size>();
7373

74-
uint32_t WorkPI = 128; // empirically found number
74+
constexpr uint32_t WorkPI = 128; // empirically found number
7575
auto global_size = Align(CeilDiv(size, WorkPI), local_size);
7676

7777
auto nd_range =
7878
sycl::nd_range(sycl::range<1>(global_size), sycl::range<1>(local_size));
7979

8080
return exec_q.submit([&](sycl::handler &cgh) {
8181
cgh.depends_on(depends);
82-
uint32_t dims = 1;
82+
constexpr uint32_t dims = 1;
8383

8484
auto dispatch_edges = [&](uint32_t local_mem, auto &weights,
8585
auto &hist) {
@@ -239,7 +239,7 @@ std::tuple<sycl::event, sycl::event>
239239
exec_q, {sample, bins, histogram}, {ev});
240240
}
241241

242-
return {ev, args_ev};
242+
return {args_ev, ev};
243243
}
244244

245245
std::unique_ptr<Histogram> hist;

dpnp/backend/extensions/sycl_ext/histogram.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,14 @@
2828
#include <sycl/sycl.hpp>
2929

3030
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
31-
// using dpctl::tensor::type_dispatch::num_types;
3231

3332
namespace sycl_ext
3433
{
3534
namespace histogram
3635
{
3736
struct Histogram
3837
{
39-
using FnT = sycl::event (*)(sycl::queue,
38+
using FnT = sycl::event (*)(sycl::queue &,
4039
const void *,
4140
const void *,
4241
const void *,

dpnp/backend/extensions/sycl_ext/histogram_common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ namespace histogram
4343
{
4444

4545
template <typename N, typename D>
46-
N CeilDiv(N n, D d)
46+
constexpr auto CeilDiv(N n, D d)
4747
{
4848
return (n + d - 1) / d;
4949
}
5050

5151
template <typename N, typename D>
52-
N Align(N n, D d)
52+
constexpr auto Align(N n, D d)
5353
{
5454
return CeilDiv(n, d) * d;
5555
}

dpnp/dpnp_iface_histograms.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,12 @@
4444
import numpy
4545

4646
import dpnp
47-
import dpnp.backend.extensions.sycl_ext._sycl_ext_impl as sycl_ext # pylint: disable=C0301,E0611
47+
48+
# pylint: disable=no-name-in-module
49+
import dpnp.backend.extensions.sycl_ext._sycl_ext_impl as sycl_ext
50+
51+
# pylint: disable=no-name-in-module
52+
from .dpnp_utils import map_dtype_to_device
4853

4954
__all__ = [
5055
"digitize",
@@ -318,19 +323,14 @@ def _find_supported_dtype(dt, supported):
318323
return None
319324

320325

321-
def _result_type(dtype1, dtype2, has_fp64):
326+
def _result_type_for_device(dtype1, dtype2, device):
322327
rt = dpnp.result_type(dtype1, dtype2)
323-
if rt == dpnp.float64 and not has_fp64:
324-
return dpnp.float32
325-
326-
if rt == dpnp.complex128 and not has_fp64:
327-
return dpnp.complex64
328+
return map_dtype_to_device(rt, device)
328329

329-
return rt
330330

331-
332-
def _align_dtypes(a_dtype, bins_dtype, ntype, has_fp64):
333-
a_bin_dtype = _result_type(a_dtype, bins_dtype, has_fp64)
331+
def _align_dtypes(a_dtype, bins_dtype, ntype, device):
332+
has_fp64 = device.has_aspect_fp64
333+
a_bin_dtype = _result_type_for_device(a_dtype, bins_dtype, device)
334334

335335
supported_types = (dpnp.float32, dpnp.int64, numpy.uint64, dpnp.complex64)
336336
if has_fp64:
@@ -347,7 +347,7 @@ def _align_dtypes(a_dtype, bins_dtype, ntype, has_fp64):
347347
if (a_bin_dtype in float_types and hist_dtype in float_types) or (
348348
a_bin_dtype in complex_types and hist_dtype in complex_types
349349
):
350-
common_type = _result_type(a_bin_dtype, hist_dtype, has_fp64)
350+
common_type = _result_type_for_device(a_bin_dtype, hist_dtype, device)
351351
a_bin_dtype = common_type
352352
hist_dtype = common_type
353353

@@ -463,10 +463,10 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
463463
ntype = weights.dtype
464464

465465
queue = a.sycl_queue
466-
has_fp64 = queue.sycl_device.has_aspect_fp64
466+
device = queue.sycl_device
467467

468468
a_bin_dtype, hist_dtype = _align_dtypes(
469-
a.dtype, bin_edges.dtype, ntype, has_fp64
469+
a.dtype, bin_edges.dtype, ntype, device
470470
)
471471

472472
if a_bin_dtype is None or hist_dtype is None:
@@ -487,7 +487,10 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
487487
else None
488488
)
489489

490+
# histogram implementation uses atomics, but atomics don't work with
491+
# host USM memory
490492
n_usm_type = "device" if usm_type == "host" else usm_type
493+
491494
n_casted = dpnp.zeros(
492495
bin_edges.size - 1,
493496
dtype=hist_dtype,
@@ -521,7 +524,6 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
521524
n = dpnp.astype(n_casted, ntype, copy=False)
522525

523526
if density:
524-
# pylint: disable=possibly-used-before-assignment
525527
db = dpnp.diff(bin_edges).astype(
526528
dpnp.default_float_type(sycl_queue=queue)
527529
)

0 commit comments

Comments
 (0)