1717template  <typename  T>
1818static  void  im2col_kernel (const  float  * x, T * dst, int64_t  batch_offset, int64_t  offset_delta, int64_t  IC, int64_t  IW,
1919                          int64_t  IH, int64_t  OH, int64_t  OW, int64_t  KW, int64_t  KH, int64_t  pelements, int64_t  CHW,
20-                           int  s0, int  s1, int  p0, int  p1, int  d0, int  d1, const  sycl::nd_item<3 > & item) {
21-     const  int64_t  work_group_size_x = item.get_local_range (2 );
22-     const  int64_t  total_threads_x   = work_group_size_x * item.get_group_range (2 );
23-     const  int64_t  global_id_x       = item.get_global_id (2 );
20+                           int  s0, int  s1, int  p0, int  p1, int  d0, int  d1, const  sycl::nd_item<3 > & item_ctl) {
21+     const  int64_t  work_group_size = item_ctl.get_local_range (2 );
22+     const  int64_t  global_id       = item_ctl.get_local_id (2 ) + (work_group_size * item_ctl.get_group (2 ));
2423
25-     for  (int64_t  i = global_id_x; i < pelements; i += total_threads_x) {
24+     //  make each work-item deal with more elements since sycl global range can not exceed max int
25+     for  (int64_t  i = global_id; i < pelements; i += (work_group_size * item_ctl.get_group_range (2 ))) {
2626        const  int64_t  ksize = OW * (KH > 1  ? KW : 1 );
2727        const  int64_t  kx    = i / ksize;
2828        const  int64_t  kd    = kx * ksize;
2929        const  int64_t  ky    = (i - kd) / OW;
3030        const  int64_t  ix    = i % OW;
3131
32-         const  int64_t  oh      = item.get_group (1 );
33-         const  int64_t  group_z = item.get_group (0 );
34-         const  int64_t  batch   = group_z / IC;
35-         const  int64_t  ic      = group_z % IC;
32+         const  int64_t  oh    = item_ctl.get_group (1 );
33+         const  int64_t  batch = item_ctl.get_group (0 ) / IC;
34+         const  int64_t  ic    = item_ctl.get_group (0 ) % IC;
3635
3736        const  int64_t  iiw = (ix * s0) + (kx * d0) - p0;
3837        const  int64_t  iih = (oh * s1) + (ky * d1) - p1;
@@ -58,11 +57,13 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I
5857                                 int64_t  KH, int64_t  IC, int64_t  batch, int64_t  batch_offset, int64_t  offset_delta,
5958                                 int  s0, int  s1, int  p0, int  p1, int  d0, int  d1, queue_ptr stream) {
6059    const  int64_t  parallel_elements = OW * KW * KH;
61-     const  int64_t  block_size_x      = SYCL_IM2COL_BLOCK_SIZE;
62-     const  int64_t  num_groups_x      = (parallel_elements + block_size_x - 1 ) / block_size_x;
60+     const  int64_t  num_blocks        = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1 ) / SYCL_IM2COL_BLOCK_SIZE;
6361
64-     sycl::range<3 > block_nums (batch * IC, OH, num_groups_x);
65-     sycl::range<3 > local_range (1 , 1 , block_size_x);
62+     //  decrease global range when it exceeds the max int
63+     int64_t  local_size = downsample_sycl_global_range (batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
64+ 
65+     sycl::range<3 > block_nums (batch * IC, OH, num_blocks);
66+     sycl::range<3 > local_range (1 , 1 , local_size);
6667
6768    const  int64_t  CHW = IC * KH * KW;
6869
@@ -130,4 +131,3 @@ void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
130131                        batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
131132    }
132133}
133- 
0 commit comments