Skip to content

Commit 21bbe8e

Browse files
committed
restore format
1 parent a6f3aca commit 21bbe8e

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

ggml/src/ggml-sycl/im2col.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
//
12
// MIT license
23
// Copyright (C) 2024 Intel Corporation
34
// SPDX-License-Identifier: MIT
5+
//
6+
47
//
58
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
69
// See https://llvm.org/LICENSE.txt for license information.
@@ -17,21 +20,21 @@
1720
template <typename T>
1821
static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW,
1922
int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
20-
int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ctl) {
21-
const int64_t work_group_size = item_ctl.get_local_range(2);
22-
const int64_t global_id = item_ctl.get_local_id(2) + (work_group_size * item_ctl.get_group(2));
23+
int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) {
24+
const int64_t work_group_size = item_ct1.get_local_range(2);
25+
const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2));
2326

2427
// make each work-item deal with more elements since sycl global range can not exceed max int
25-
for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ctl.get_group_range(2))) {
28+
for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) {
2629
const int64_t ksize = OW * (KH > 1 ? KW : 1);
2730
const int64_t kx = i / ksize;
2831
const int64_t kd = kx * ksize;
2932
const int64_t ky = (i - kd) / OW;
3033
const int64_t ix = i % OW;
3134

32-
const int64_t oh = item_ctl.get_group(1);
33-
const int64_t batch = item_ctl.get_group(0) / IC;
34-
const int64_t ic = item_ctl.get_group(0) % IC;
35+
const int64_t oh = item_ct1.get_group(1);
36+
const int64_t batch = item_ct1.get_group(0) / IC;
37+
const int64_t ic = item_ct1.get_group(0) % IC;
3538

3639
const int64_t iiw = (ix * s0) + (kx * d0) - p0;
3740
const int64_t iih = (oh * s1) + (ky * d1) - p1;
@@ -67,9 +70,9 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I
6770

6871
const int64_t CHW = IC * KH * KW;
6972

70-
stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item) {
73+
stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
7174
im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,
72-
p0, p1, d0, d1, item);
75+
p0, p1, d0, d1, item_ct1);
7376
});
7477
}
7578

0 commit comments

Comments
 (0)