|
| 1 | +// |
1 | 2 | // MIT license |
2 | 3 | // Copyright (C) 2024 Intel Corporation |
3 | 4 | // SPDX-License-Identifier: MIT |
| 5 | +// |
| 6 | + |
4 | 7 | // |
5 | 8 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
6 | 9 | // See https://llvm.org/LICENSE.txt for license information. |
|
17 | 20 | template <typename T> |
18 | 21 | static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW, |
19 | 22 | int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW, |
20 | | - int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ctl) { |
21 | | - const int64_t work_group_size = item_ctl.get_local_range(2); |
22 | | - const int64_t global_id = item_ctl.get_local_id(2) + (work_group_size * item_ctl.get_group(2)); |
| 23 | + int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) { |
| 24 | + const int64_t work_group_size = item_ct1.get_local_range(2); |
| 25 | + const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2)); |
23 | 26 |
|
24 | 27 | // make each work-item deal with more elements since sycl global range can not exceed max int |
25 | | - for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ctl.get_group_range(2))) { |
| 28 | + for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) { |
26 | 29 | const int64_t ksize = OW * (KH > 1 ? KW : 1); |
27 | 30 | const int64_t kx = i / ksize; |
28 | 31 | const int64_t kd = kx * ksize; |
29 | 32 | const int64_t ky = (i - kd) / OW; |
30 | 33 | const int64_t ix = i % OW; |
31 | 34 |
|
32 | | - const int64_t oh = item_ctl.get_group(1); |
33 | | - const int64_t batch = item_ctl.get_group(0) / IC; |
34 | | - const int64_t ic = item_ctl.get_group(0) % IC; |
| 35 | + const int64_t oh = item_ct1.get_group(1); |
| 36 | + const int64_t batch = item_ct1.get_group(0) / IC; |
| 37 | + const int64_t ic = item_ct1.get_group(0) % IC; |
35 | 38 |
|
36 | 39 | const int64_t iiw = (ix * s0) + (kx * d0) - p0; |
37 | 40 | const int64_t iih = (oh * s1) + (ky * d1) - p1; |
@@ -67,9 +70,9 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I |
67 | 70 |
|
68 | 71 | const int64_t CHW = IC * KH * KW; |
69 | 72 |
|
70 | | - stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item) { |
| 73 | + stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { |
71 | 74 | im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, |
72 | | - p0, p1, d0, d1, item); |
| 75 | + p0, p1, d0, d1, item_ct1); |
73 | 76 | }); |
74 | 77 | } |
75 | 78 |
|
|
0 commit comments