Skip to content

Commit a7b6c00

Browse files
Optimized BiLiear 2D Up Sampling for AMD MI devices
Cherry-pick of #2729 Co-authored-by: glen-amd <[email protected]>
1 parent 6305267 commit a7b6c00

File tree

1 file changed

+101
-2
lines changed

1 file changed

+101
-2
lines changed

aten/src/ATen/native/cuda/UpSampleBilinear2d.cu

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
127127
}
128128
}
129129

130+
#ifdef USE_ROCM
131+
// Helper function to compute output pixel range that can contribute to input pixel
132+
template <typename accscalar_t>
133+
__device__ __forceinline__ void compute_output_range(
134+
int input_pos,
135+
accscalar_t scale,
136+
int output_size,
137+
bool align_corners,
138+
int& min_output,
139+
int& max_output) {
140+
accscalar_t lo, hi;
141+
if (align_corners) {
142+
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
143+
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
144+
} else {
145+
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
146+
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
147+
}
148+
min_output = max(0, static_cast<int>(std::ceil(lo)));
149+
max_output = min(output_size - 1, static_cast<int>(std::floor(hi)));
150+
}
151+
#endif
152+
130153
// Backward (adjoint) operation 1 <- 2 (accumulates)
131154
template <typename scalar_t, typename accscalar_t>
132155
C10_LAUNCH_BOUNDS_1(1024)
@@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
141164
const bool align_corners,
142165
scalar_t* __restrict__ idata,
143166
const scalar_t* __restrict__ odata) {
144-
const size_t o_numel = nc * width2 * height2;
167+
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
145168
const size_t i_numel = nc * width1 * height1;
169+
#ifdef USE_ROCM
170+
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
171+
index += blockDim.x * gridDim.x) {
172+
// Decode input pixel coordinates
173+
size_t index_temp = index;
174+
const int w1 = index_temp % width1;
175+
index_temp /= width1;
176+
const int h1 = index_temp % height1;
177+
const size_t nc_idx = index_temp / height1;
178+
179+
accscalar_t grad_sum = 0;
180+
181+
// Find range of output pixels that could interpolate from this input pixel
182+
int h2_min, h2_max, w2_min, w2_max;
183+
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
184+
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
185+
186+
// Iterate over potential output pixels
187+
for (int h2 = h2_min; h2 <= h2_max; h2++) {
188+
for (int w2 = w2_min; w2 <= w2_max; w2++) {
189+
// Compute source coordinates for this output pixel
190+
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
191+
rheight, h2, align_corners, /*cubic=*/false);
192+
const int h1_base = (int)h1r;
193+
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
194+
const accscalar_t h1lambda = h1r - h1_base;
195+
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
196+
197+
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
198+
rwidth, w2, align_corners, /*cubic=*/false);
199+
const int w1_base = (int)w1r;
200+
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
201+
const accscalar_t w1lambda = w1r - w1_base;
202+
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
203+
204+
// Check if our input pixel participates in this interpolation and accumulate all weights
205+
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
206+
// to the same pixel, so we need to accumulate weights from all matching positions
207+
accscalar_t weight = 0;
208+
209+
// Check all four interpolation positions and accumulate weights
210+
if (h1 == h1_base && w1 == w1_base) {
211+
weight += h0lambda * w0lambda; // top-left
212+
}
213+
if (h1 == h1_base && w1 == w1_base + w1p) {
214+
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
215+
}
216+
if (h1 == h1_base + h1p && w1 == w1_base) {
217+
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
218+
}
219+
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
220+
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
221+
}
222+
223+
if (weight > 0) {
224+
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
225+
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
226+
}
227+
}
228+
}
229+
230+
// Write accumulated gradient (no atomics needed)
231+
idata[index] = static_cast<scalar_t>(grad_sum);
232+
}
233+
#else
234+
const size_t o_numel = nc * width2 * height2;
146235
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
147236
index += blockDim.x * gridDim.x) {
148237
size_t index_temp = index;
@@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
191280
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
192281
true);
193282
}
283+
#endif
194284
}
195285

196286
template <typename scalar_t, typename accscalar_t>
@@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
387477
// threads are not covering the whole input tensor.
388478
grad_input.zero_();
389479

390-
const size_t num_kernels = nbatch * channels * output_height * output_width;
391480
const int num_threads = std::min(
392481
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
393482
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
397486
return;
398487
}
399488

489+
#ifdef USE_ROCM
490+
constexpr bool use_input = true;
491+
#else
492+
constexpr bool use_input = false;
493+
#endif
494+
400495
AT_DISPATCH_FLOATING_TYPES_AND2(
401496
at::ScalarType::Half, at::ScalarType::BFloat16,
402497
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
414509
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
415510
input_width, output_width, align_corners, scales_w);
416511

512+
const size_t num_kernels = nbatch * channels * output_height * output_width;
513+
417514
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
418515
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
419516
input_height,
@@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
444541
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
445542
input_width, output_width, align_corners, scales_w);
446543

544+
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
545+
447546
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
448547
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
449548
num_threads,

0 commit comments

Comments
 (0)