Skip to content

Commit e71d223

Browse files
authored
Use syclMaxWorkGroupSize instead of dpcppMaxWorkGroupSize in OP take (#4925) (#4953)
* Use syclMaxWorkGroupSize instead of dpcppMaxWorkGroupSize in OP take
* Fix comments
1 parent 141ac2f commit e71d223

File tree

3 files changed

+24
-9
lines changed

3 files changed

+24
-9
lines changed

csrc/gpu/aten/operators/Indexing.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,10 +1366,6 @@ void take_dpcpp(Tensor& dst, const Tensor& src, const Tensor& index) {
13661366
idx_info.collapseDims();
13671367

13681368
auto& dpcpp_queue = dpcppGetCurrentQueue();
1369-
auto dev_id = dpcppGetDeviceIdOfCurrentQueue();
1370-
auto wgroup_size = dpcppMaxWorkGroupSize(dev_id);
1371-
auto wgroup_range = (dst_num_elem + wgroup_size - 1) / wgroup_size;
1372-
13731369
auto cgf = DPCPP_Q_CGF(cgh) {
13741370
auto src_data = src.data_ptr<scalar_t>();
13751371
auto dst_data = dst.data_ptr<scalar_t>();
@@ -1384,6 +1380,8 @@ void take_dpcpp(Tensor& dst, const Tensor& src, const Tensor& index) {
13841380
src_data,
13851381
dst_data,
13861382
idx_data);
1383+
auto wgroup_size = dpcppMaxWorkGroupSize(kfn);
1384+
auto wgroup_range = (dst_num_elem + wgroup_size - 1) / wgroup_size;
13871385

13881386
cgh.parallel_for<decltype(kfn)>(
13891387
sycl::nd_range<1>({wgroup_range * wgroup_size}, {wgroup_size}), kfn);

csrc/gpu/runtime/Utils.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,28 @@ static inline sycl::queue& dpcppGetCurrentQueue() {
2121
return at::xpu::getCurrentXPUStream().queue();
2222
}
2323

24+
template <class KernelClass>
25+
static int64_t dpcppMaxWorkGroupSize(
26+
at::DeviceIndex dev_id = dpcppGetDeviceIdOfCurrentQueue()) {
27+
auto q = c10::xpu::getCurrentXPUStream(dev_id).queue();
28+
auto ctx = q.get_context();
29+
auto dev = q.get_device();
30+
31+
auto kid = ::sycl::get_kernel_id<KernelClass>();
32+
auto kbundle =
33+
::sycl::get_kernel_bundle<::sycl::bundle_state::executable>(ctx, {kid});
34+
35+
::sycl::kernel k = kbundle.get_kernel(kid);
36+
return k.get_info<::sycl::info::kernel_device_specific::work_group_size>(dev);
37+
}
38+
39+
template <class KernelClass>
40+
static int64_t dpcppMaxWorkGroupSize(
41+
KernelClass /*kfn*/,
42+
at::DeviceIndex dev_id = dpcppGetDeviceIdOfCurrentQueue()) {
43+
return dpcppMaxWorkGroupSize<KernelClass>(dev_id);
44+
}
45+
2446
static inline int64_t dpcppMaxWorkGroupSize(
2547
DeviceId dev_id = dpcppGetDeviceIdOfCurrentQueue()) {
2648
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);

tests/gpu/examples/test_take.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,12 @@
22
from torch.testing._internal.common_utils import TestCase
33

44
import intel_extension_for_pytorch # noqa
5-
import pytest
65

76
cpu_device = torch.device("cpu")
87
dpcpp_device = torch.device("xpu")
98

109

1110
class TestNNMethod(TestCase):
12-
@pytest.mark.skip(
13-
reason="PT2.5: Total number of work-items in a work-group cannot exceed 512 for this kernel \
14-
-54 (PI_ERROR_INVALID_WORK_GROUP_SIZE)"
15-
)
1611
def test_take(self, dtype=torch.float):
1712
src = torch.rand(2, 3)
1813
print(src)

0 commit comments

Comments (0)