@@ -1382,7 +1382,6 @@ void launch_upsample_gen2d_aa_kernel(
1382
1382
const int interp_width = (int )ceilf (support_w) * 2 + 1 ;
1383
1383
1384
1384
auto sharedMemPerBlock = syclLocalMemSize ();
1385
- auto total_threads = syclMaxWorkItemsPerTile ();
1386
1385
int maxThreadsPerBlock = std::min<int >(
1387
1386
syclMaxWorkGroupSize<
1388
1387
UpsampleGen2dAaKernelFunctor<scalar_t , accscalar_t , InterpFilter>>(),
@@ -1395,13 +1394,9 @@ void launch_upsample_gen2d_aa_kernel(
1395
1394
int block_y = lastPow2 ((unsigned int )(numer / denom));
1396
1395
block_y = std::min<int >(maxThreadsPerBlock / block_x, block_y);
1397
1396
1398
- int grid_x = std::min<int >(
1399
- total_threads, (output_width + block_x - 1 ) / block_x * block_x);
1400
- int grid_y = std::min<int >(
1401
- total_threads / grid_x,
1402
- (output_height + block_y - 1 ) / block_y * block_y);
1403
- int grid_z =
1404
- std::min<int >(total_threads / grid_x / grid_y, nbatch * channels);
1397
+ int grid_x = (output_width + block_x - 1 ) / block_x * block_x;
1398
+ int grid_y = (output_height + block_y - 1 ) / block_y * block_y;
1399
+ int grid_z = nbatch * channels;
1405
1400
1406
1401
int64_t weights_per_block = interp_width * block_x + interp_height * block_y;
1407
1402
weights_per_block += interp_height * block_y * block_x;
@@ -1455,21 +1450,16 @@ void launch_upsample_gen2d_aa_backward_kernel(
1455
1450
auto queue = getCurrentSYCLQueue ();
1456
1451
1457
1452
auto sharedMemPerBlock = syclLocalMemSize ();
1458
- auto total_threads = syclMaxWorkItemsPerTile ();
1459
1453
int maxThreadsPerBlock = std::min<int >(
1460
1454
syclMaxWorkGroupSize<
1461
1455
UpsampleGen2dAaKernelFunctor<scalar_t , accscalar_t , InterpFilter>>(),
1462
1456
256 ); // 256 performs better
1463
1457
int block_x = syclMaxSubGroupSize ();
1464
1458
int block_y = maxThreadsPerBlock / block_x;
1465
1459
1466
- int grid_x = std::min<int >(
1467
- total_threads, (output_width + block_x - 1 ) / block_x * block_x);
1468
- int grid_y = std::min<int >(
1469
- total_threads / grid_x,
1470
- (output_height + block_y - 1 ) / block_y * block_y);
1471
- int grid_z =
1472
- std::min<int >(total_threads / grid_x / grid_y, nbatch * channels);
1460
+ int grid_x = (output_width + block_x - 1 ) / block_x * block_x;
1461
+ int grid_y = (output_height + block_y - 1 ) / block_y * block_y;
1462
+ int grid_z = nbatch * channels;
1473
1463
1474
1464
const int interp_height = (int )ceilf (support_h) * 2 + 1 ;
1475
1465
const int interp_width = (int )ceilf (support_w) * 2 + 1 ;
0 commit comments