@@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
127127 }
128128}
129129
130+ #ifdef USE_ROCM
131+ // Helper function to compute output pixel range that can contribute to input pixel
// Given an input pixel index along one dimension, compute the inclusive
// window [min_output, max_output] of output pixels whose bilinear source
// interval can overlap that input pixel. `scale` maps output -> input
// coordinates (same convention as area_pixel_compute_source_index).
// Endpoints that end up contributing zero weight are harmless: the caller
// filters them out with its weight > 0 check.
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
    int input_pos,
    accscalar_t scale,
    int output_size,
    bool align_corners,
    int& min_output,
    int& max_output) {
  // An output pixel o reads input coordinate src(o); it touches input_pos
  // whenever src(o) lies in (input_pos - 1, input_pos + 1). Invert src()
  // to get the corresponding output-coordinate interval [out_lo, out_hi].
  accscalar_t out_lo, out_hi;
  if (align_corners) {
    // src(o) = scale * o
    out_lo = static_cast<accscalar_t>(input_pos - 1) / scale;
    out_hi = static_cast<accscalar_t>(input_pos + 1) / scale;
  } else {
    // src(o) = scale * (o + 0.5) - 0.5
    out_lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale -
        static_cast<accscalar_t>(0.5);
    out_hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale -
        static_cast<accscalar_t>(0.5);
  }
  // Clamp the real-valued interval to valid integer output indices.
  min_output = max(0, static_cast<int>(std::ceil(out_lo)));
  max_output = min(output_size - 1, static_cast<int>(std::floor(out_hi)));
}
151+ #endif
152+
130153// Backward (adjoint) operation 1 <- 2 (accumulates)
131154template <typename scalar_t , typename accscalar_t >
132155C10_LAUNCH_BOUNDS_1 (1024 )
@@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
141164 const bool align_corners,
142165 scalar_t * __restrict__ idata,
143166 const scalar_t * __restrict__ odata) {
144- const size_t o_numel = nc * width2 * height2;
167+ // i_numel is needed up front by both paths; o_numel is only used by the
167+ // atomic scatter path and is now computed inside the #else branch below.
145168 const size_t i_numel = nc * width1 * height1;
169+ #ifdef USE_ROCM
170+ for (size_t index = blockDim .x * blockIdx .x + threadIdx .x ; index < i_numel;
171+ index += blockDim .x * gridDim .x ) {
172+ // Decode input pixel coordinates
173+ size_t index_temp = index;
174+ const int w1 = index_temp % width1;
175+ index_temp /= width1;
176+ const int h1 = index_temp % height1;
177+ const size_t nc_idx = index_temp / height1;
178+
179+ accscalar_t grad_sum = 0 ;
180+
181+ // Find range of output pixels that could interpolate from this input pixel
182+ int h2_min, h2_max, w2_min, w2_max;
183+ compute_output_range<accscalar_t >(h1, rheight, height2, align_corners, h2_min, h2_max);
184+ compute_output_range<accscalar_t >(w1, rwidth, width2, align_corners, w2_min, w2_max);
185+
186+ // Iterate over potential output pixels
187+ for (int h2 = h2_min; h2 <= h2_max; h2++) {
188+ for (int w2 = w2_min; w2 <= w2_max; w2++) {
189+ // Compute source coordinates for this output pixel
190+ const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t >(
191+ rheight, h2, align_corners, /* cubic=*/ false );
192+ const int h1_base = (int )h1r;
193+ const int h1p = (h1_base < height1 - 1 ) ? 1 : 0 ;
194+ const accscalar_t h1lambda = h1r - h1_base;
195+ const accscalar_t h0lambda = static_cast <accscalar_t >(1 ) - h1lambda;
196+
197+ const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t >(
198+ rwidth, w2, align_corners, /* cubic=*/ false );
199+ const int w1_base = (int )w1r;
200+ const int w1p = (w1_base < width1 - 1 ) ? 1 : 0 ;
201+ const accscalar_t w1lambda = w1r - w1_base;
202+ const accscalar_t w0lambda = static_cast <accscalar_t >(1 ) - w1lambda;
203+
204+ // Check if our input pixel participates in this interpolation and accumulate all weights
205+ // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
206+ // to the same pixel, so we need to accumulate weights from all matching positions
207+ accscalar_t weight = 0 ;
208+
209+ // Check all four interpolation positions and accumulate weights
210+ if (h1 == h1_base && w1 == w1_base) {
211+ weight += h0lambda * w0lambda; // top-left
212+ }
213+ if (h1 == h1_base && w1 == w1_base + w1p) {
214+ weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
215+ }
216+ if (h1 == h1_base + h1p && w1 == w1_base) {
217+ weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
218+ }
219+ if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
220+ weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
221+ }
222+
223+ if (weight > 0 ) {
224+ const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
225+ grad_sum += weight * static_cast <accscalar_t >(odata[output_idx]);
226+ }
227+ }
228+ }
229+
230+ // Write accumulated gradient (no atomics needed)
231+ idata[index] = static_cast <scalar_t >(grad_sum);
232+ }
233+ #else
234+ const size_t o_numel = nc * width2 * height2;
146235 for (size_t index = blockDim .x * blockIdx .x + threadIdx .x ; index < o_numel;
147236 index += blockDim .x * gridDim .x ) {
148237 size_t index_temp = index;
@@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
191280 static_cast <scalar_t >(h1lambda * w1lambda * d2val),
192281 true );
193282 }
283+ #endif
194284}
195285
196286template <typename scalar_t , typename accscalar_t >
@@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
387477 // threads are not covering the whole input tensor.
388478 grad_input.zero_ ();
389479
390- const size_t num_kernels = nbatch * channels * output_height * output_width;
391480 const int num_threads = std::min (
392481 at::cuda::getCurrentDeviceProperties ()->maxThreadsPerBlock , 1024 );
393482 cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
@@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
397486 return ;
398487 }
399488
489+ #ifdef USE_ROCM
490+ constexpr bool use_input = true ;
491+ #else
492+ constexpr bool use_input = false ;
493+ #endif
494+
400495 AT_DISPATCH_FLOATING_TYPES_AND2 (
401496 at::ScalarType::Half, at::ScalarType::BFloat16,
402497 grad_output_.scalar_type (), " upsample_bilinear2d_backward_out_frame" , [&] {
@@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
414509 const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t >(
415510 input_width, output_width, align_corners, scales_w);
416511
512+ const size_t num_kernels = nbatch * channels * output_height * output_width;
513+
417514 upsample_bilinear2d_backward_nhwc_out_frame<scalar_t , accscalar_t >
418515 <<<ceil_div(num_kernels, static_cast <size_t >(num_threads)), num_threads, 0 , stream>>> (
419516 input_height,
@@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
444541 const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t >(
445542 input_width, output_width, align_corners, scales_w);
446543
544+ const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
545+
447546 upsample_bilinear2d_backward_out_frame<scalar_t , accscalar_t >
448547 <<<ceil_div(num_kernels, static_cast <size_t >(num_threads)),
449548 num_threads,
0 commit comments