Skip to content

Commit b979b55

Browse files
authored
UpsampleBilinear2Daa: Correct the global range (#1942)
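Correct the global range: each dimension of the global range must be a multiple of the work-group size in that dimension, so the range is no longer clamped by the per-tile work-item limit. Fixes #1465.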
1 parent f2bcd8a commit b979b55

File tree

1 file changed

+6
-16
lines changed

1 file changed

+6
-16
lines changed

src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp

Lines changed: 6 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1382,7 +1382,6 @@ void launch_upsample_gen2d_aa_kernel(
13821382
const int interp_width = (int)ceilf(support_w) * 2 + 1;
13831383

13841384
auto sharedMemPerBlock = syclLocalMemSize();
1385-
auto total_threads = syclMaxWorkItemsPerTile();
13861385
int maxThreadsPerBlock = std::min<int>(
13871386
syclMaxWorkGroupSize<
13881387
UpsampleGen2dAaKernelFunctor<scalar_t, accscalar_t, InterpFilter>>(),
@@ -1395,13 +1394,9 @@ void launch_upsample_gen2d_aa_kernel(
13951394
int block_y = lastPow2((unsigned int)(numer / denom));
13961395
block_y = std::min<int>(maxThreadsPerBlock / block_x, block_y);
13971396

1398-
int grid_x = std::min<int>(
1399-
total_threads, (output_width + block_x - 1) / block_x * block_x);
1400-
int grid_y = std::min<int>(
1401-
total_threads / grid_x,
1402-
(output_height + block_y - 1) / block_y * block_y);
1403-
int grid_z =
1404-
std::min<int>(total_threads / grid_x / grid_y, nbatch * channels);
1397+
int grid_x = (output_width + block_x - 1) / block_x * block_x;
1398+
int grid_y = (output_height + block_y - 1) / block_y * block_y;
1399+
int grid_z = nbatch * channels;
14051400

14061401
int64_t weights_per_block = interp_width * block_x + interp_height * block_y;
14071402
weights_per_block += interp_height * block_y * block_x;
@@ -1455,21 +1450,16 @@ void launch_upsample_gen2d_aa_backward_kernel(
14551450
auto queue = getCurrentSYCLQueue();
14561451

14571452
auto sharedMemPerBlock = syclLocalMemSize();
1458-
auto total_threads = syclMaxWorkItemsPerTile();
14591453
int maxThreadsPerBlock = std::min<int>(
14601454
syclMaxWorkGroupSize<
14611455
UpsampleGen2dAaKernelFunctor<scalar_t, accscalar_t, InterpFilter>>(),
14621456
256); // 256 performs better
14631457
int block_x = syclMaxSubGroupSize();
14641458
int block_y = maxThreadsPerBlock / block_x;
14651459

1466-
int grid_x = std::min<int>(
1467-
total_threads, (output_width + block_x - 1) / block_x * block_x);
1468-
int grid_y = std::min<int>(
1469-
total_threads / grid_x,
1470-
(output_height + block_y - 1) / block_y * block_y);
1471-
int grid_z =
1472-
std::min<int>(total_threads / grid_x / grid_y, nbatch * channels);
1460+
int grid_x = (output_width + block_x - 1) / block_x * block_x;
1461+
int grid_y = (output_height + block_y - 1) / block_y * block_y;
1462+
int grid_z = nbatch * channels;
14731463

14741464
const int interp_height = (int)ceilf(support_h) * 2 + 1;
14751465
const int interp_width = (int)ceilf(support_w) * 2 + 1;

0 commit comments

Comments
 (0)