
Commit c22226f

Merge branch 'release/2.8' into vmijovic/change_default_to_hipblaslt

2 parents f1e6ec9 + 3658645

6 files changed, +151 −7 lines changed
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-f9e5bf54a2fe1a6262a41b27b38180cdb6fae6a2
+21876a4bbaf371bcb83df8e6ee4f43a92f524dfe

aten/src/ATen/Context.cpp

Lines changed: 6 additions & 0 deletions

@@ -341,6 +341,9 @@ at::BlasBackend Context::blasPreferredBackend() {
 #if ROCM_VERSION >= 60402
         "gfx1150", "gfx1151",
 #endif
+#if ROCM_VERSION >= 60402
+        "gfx1150", "gfx1151",
+#endif
 #if ROCM_VERSION >= 60500
         "gfx950"
 #endif
@@ -370,6 +373,9 @@ at::BlasBackend Context::blasPreferredBackend() {
 #if ROCM_VERSION >= 60402
         "gfx1150", "gfx1151",
 #endif
+#if ROCM_VERSION >= 60402
+        "gfx1150", "gfx1151",
+#endif
 #if ROCM_VERSION >= 60500
         "gfx950"
 #endif

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 3 additions & 0 deletions

@@ -275,6 +275,9 @@ static bool isSupportedHipLtROCmArch(int index) {
 #if ROCM_VERSION >= 60402
         "gfx1150", "gfx1151",
 #endif
+#if ROCM_VERSION >= 60402
+        "gfx1150", "gfx1151",
+#endif
 #if ROCM_VERSION >= 60500
         "gfx950"
 #endif
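Each of these hunks adds a second copy of an already-present gfx1150/gfx1151 guard block, apparently an artifact of the merge from release/2.8; the duplicate entries are redundant but benign if the list only feeds a membership test. As a hedged sketch of how such an arch allow-list is typically consulted before preferring hipBLASLt (hypothetical helper, not the code from this commit) — the device's reported gcnArchName is matched against the list:

#include <string>
#include <vector>

// Illustrative only: real gcnArchName strings can carry feature suffixes
// such as "gfx90a:sramecc+:xnack-", so each allow-list entry is matched
// as a prefix of the reported name.
static bool archAllowsHipblaslt(const std::string& gcn_arch_name,
                                const std::vector<std::string>& allowlist) {
  for (const std::string& arch : allowlist) {
    if (gcn_arch_name.rfind(arch, 0) == 0) {
      return true;
    }
  }
  return false;
}
// e.g. archAllowsHipblaslt("gfx1151", {"gfx1150", "gfx1151"}) == true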

aten/src/ATen/native/cuda/Normalization.cuh

Lines changed: 33 additions & 2 deletions

@@ -23,7 +23,7 @@ namespace at::native {
 
 // The maximum number of threads in a block
 #if defined(USE_ROCM)
-constexpr int MAX_BLOCK_SIZE = 256;
+constexpr int MAX_BLOCK_SIZE = 1024;
 #else
 constexpr int MAX_BLOCK_SIZE = 512;
 #endif
@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
 // Number of threads in a block given an input size up to MAX_BLOCK_SIZE
 static int getNumThreads(int nElem) {
 #if defined(USE_ROCM)
-  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
+  int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
 #else
   int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
 #endif
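MAX_BLOCK_SIZE quadruples on ROCm and the candidate table shifts up accordingly. The table only lists candidates; the selection that follows it in the file (not part of this diff) returns the smallest candidate covering the plane size. A self-contained sketch of that pattern, assuming the usual smallest-fit loop, with the new ROCm values:

#include <cstdio>

constexpr int MAX_BLOCK_SIZE = 1024;  // new ROCm value from the hunk above

// Assumed smallest-fit selection: return the first candidate block size
// that covers nElem, else the maximum.
static int getNumThreads(int nElem) {
  const int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
  for (int size : threadSizes) {
    if (nElem <= size) return size;
  }
  return MAX_BLOCK_SIZE;
}

int main() {
  // A 300-element plane now gets a 512-thread block; the old ROCm table
  // { 16, 32, 64, 128, 256 } capped the same plane at 256 threads.
  std::printf("%d\n", getNumThreads(300));
  return 0;
}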
@@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
   // first the reductions each thread does separately
   scalar_t sum = static_cast<scalar_t>(0);
   for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
+#if defined(USE_ROCM)
+    constexpr int UNRL = 4; // load unroll factor
+    scalar_t tmp[UNRL];
+    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
+#pragma unroll
+      for (int u = 0; u < UNRL; u++)
+        tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
+#pragma unroll
+      for (int u = 0; u < UNRL; u++)
+        if (x+u*blockDim.x < tensor.size(2))
+          sum += tmp[u];
+    }
+#else
     for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
       sum += op(batch, plane, x);
     }
+#endif
   }
   __shared__ scalar_t shared[C10_WARP_SIZE];
   SumReduceOp<scalar_t> reduce_op;
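On ROCm the per-thread reduction now walks the plane in strides of blockDim.x*UNRL, issuing four clamped loads before any bounds-dependent work: lanes that would run past the end re-read the last element and are then masked out of the accumulation. A host-side sketch of the same access pattern (the block width is a stand-in value for the demo):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  constexpr int UNRL = 4;               // same unroll factor as the kernel
  const int blockDimX = 64;             // hypothetical block width
  std::vector<float> data(1000, 1.0f);  // stand-in for tensor[batch][plane][*]
  const int n = static_cast<int>(data.size());

  float sum = 0.0f;
  for (int t = 0; t < blockDimX; ++t) {             // simulate threadIdx.x
    for (int x = t; x < n; x += blockDimX * UNRL) {
      float tmp[UNRL];
      for (int u = 0; u < UNRL; ++u)                // clamped, branch-free loads
        tmp[u] = data[std::min(n - 1, x + u * blockDimX)];
      for (int u = 0; u < UNRL; ++u)                // mask out-of-range lanes
        if (x + u * blockDimX < n)
          sum += tmp[u];
    }
  }
  std::printf("sum = %.0f\n", sum);  // 1000, same result as the plain loop
  return 0;
}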
@@ -292,13 +306,30 @@ __global__ void batch_norm_collect_statistics_kernel(
   stat_accscalar_t var_n = 0;
   int n = 0;
   for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
+#if defined(USE_ROCM)
+    constexpr int UNRL = 4;
+    stat_accscalar_t v_[UNRL];
+    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
+      for (int u = 0; u < UNRL; u++)
+        v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
+      for (int u = 0; u < UNRL; u++) {
+        if (x+u*blockDim.x < input.size(2)) {
+          stat_accscalar_t d1 = v_[u] - avg;
+          n++;
+          avg += d1 / n;
+          var_n += d1 * (v_[u] - avg);
+        }
+      }
+    }
+#else
     for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
       stat_accscalar_t v = input[batch][plane][x];
       stat_accscalar_t d1 = v - avg;
       n++;
       avg += d1 / n;
       var_n += d1 * (v - avg);
     }
+#endif
   }
 
   // first warpSum to get one value per thread to
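The statistics loop is Welford's online mean/variance update: avg is the running mean and var_n the running sum of squared deviations (n times the biased variance), computed in a single pass; the ROCm branch batches the loads but keeps the update order intact. A minimal host-side check of the update rule:

#include <cstdio>
#include <vector>

int main() {
  const std::vector<double> v = {1.0, 2.0, 4.0, 8.0};
  double avg = 0.0, var_n = 0.0;
  int n = 0;
  for (double x : v) {
    double d1 = x - avg;
    ++n;
    avg += d1 / n;
    var_n += d1 * (x - avg);  // uses the *updated* avg, as in the kernel
  }
  // Expect mean = 3.75 and biased variance var_n / n = 7.1875.
  std::printf("mean=%.4f var=%.4f\n", avg, var_n / n);
  return 0;
}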

aten/src/ATen/native/cuda/UpSampleBilinear2d.cu

Lines changed: 101 additions & 2 deletions

@@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
   }
 }
 
+#ifdef USE_ROCM
+// Helper function to compute output pixel range that can contribute to input pixel
+template <typename accscalar_t>
+__device__ __forceinline__ void compute_output_range(
+    int input_pos,
+    accscalar_t scale,
+    int output_size,
+    bool align_corners,
+    int& min_output,
+    int& max_output) {
+  accscalar_t lo, hi;
+  if (align_corners) {
+    lo = static_cast<accscalar_t>(input_pos - 1) / scale;
+    hi = static_cast<accscalar_t>(input_pos + 1) / scale;
+  } else {
+    lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
+    hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
+  }
+  min_output = max(0, static_cast<int>(std::ceil(lo)));
+  max_output = min(output_size - 1, static_cast<int>(std::floor(hi)));
+}
+#endif
+
 // Backward (adjoint) operation 1 <- 2 (accumulates)
 template <typename scalar_t, typename accscalar_t>
 C10_LAUNCH_BOUNDS_1(1024)
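The lo/hi bounds invert the forward source-index mapping. With align_corners=false, output pixel h2 samples source position h1r = scale*(h2 + 0.5) - 0.5 and bilinear interpolation reads rows floor(h1r) and floor(h1r)+1, so input row h1 can only receive gradient when h1 - 1 <= h1r <= h1 + 1; solving for h2 yields exactly the (h1 - 0.5)/scale - 0.5 and (h1 + 1.5)/scale - 0.5 bounds above, and the align_corners=true case inverts h1r = scale*h2 the same way. The range is a conservative superset: the weight checks in the kernel below discard output pixels whose interpolation footprint does not actually touch the input pixel.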
@@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
     const bool align_corners,
     scalar_t* __restrict__ idata,
     const scalar_t* __restrict__ odata) {
-  const size_t o_numel = nc * width2 * height2;
+  // In C++, integer multiplication, like in standard arithmetic, is generally commutative.
   const size_t i_numel = nc * width1 * height1;
+#ifdef USE_ROCM
+  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
+       index += blockDim.x * gridDim.x) {
+    // Decode input pixel coordinates
+    size_t index_temp = index;
+    const int w1 = index_temp % width1;
+    index_temp /= width1;
+    const int h1 = index_temp % height1;
+    const size_t nc_idx = index_temp / height1;
+
+    accscalar_t grad_sum = 0;
+
+    // Find range of output pixels that could interpolate from this input pixel
+    int h2_min, h2_max, w2_min, w2_max;
+    compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
+    compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
+
+    // Iterate over potential output pixels
+    for (int h2 = h2_min; h2 <= h2_max; h2++) {
+      for (int w2 = w2_min; w2 <= w2_max; w2++) {
+        // Compute source coordinates for this output pixel
+        const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
+            rheight, h2, align_corners, /*cubic=*/false);
+        const int h1_base = (int)h1r;
+        const int h1p = (h1_base < height1 - 1) ? 1 : 0;
+        const accscalar_t h1lambda = h1r - h1_base;
+        const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
+
+        const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
+            rwidth, w2, align_corners, /*cubic=*/false);
+        const int w1_base = (int)w1r;
+        const int w1p = (w1_base < width1 - 1) ? 1 : 0;
+        const accscalar_t w1lambda = w1r - w1_base;
+        const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
+
+        // Check if our input pixel participates in this interpolation and accumulate all weights
+        // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
+        // to the same pixel, so we need to accumulate weights from all matching positions
+        accscalar_t weight = 0;
+
+        // Check all four interpolation positions and accumulate weights
+        if (h1 == h1_base && w1 == w1_base) {
+          weight += h0lambda * w0lambda; // top-left
+        }
+        if (h1 == h1_base && w1 == w1_base + w1p) {
+          weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
+        }
+        if (h1 == h1_base + h1p && w1 == w1_base) {
+          weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
+        }
+        if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
+          weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
+        }
+
+        if (weight > 0) {
+          const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
+          grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
+        }
+      }
+    }
+
+    // Write accumulated gradient (no atomics needed)
+    idata[index] = static_cast<scalar_t>(grad_sum);
+  }
+#else
+  const size_t o_numel = nc * width2 * height2;
   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
        index += blockDim.x * gridDim.x) {
     size_t index_temp = index;
@@ -191,6 +280,7 @@
         static_cast<scalar_t>(h1lambda * w1lambda * d2val),
         true);
   }
+#endif
 }
 
 template <typename scalar_t, typename accscalar_t>
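Net effect of these two hunks: the parallelization axis flips on ROCm. The stock path (now the #else branch) assigns one thread per output element and scatters each of its four weighted contributions into idata through atomic accumulation (the truncated "..., true);" call in the context above is the tail of one such invocation), whereas the ROCm path assigns one thread per input element and gathers every contribution for the pixel it owns, so each idata element is written exactly once and no atomics are needed.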
@@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
   // threads are not covering the whole input tensor.
   grad_input.zero_();
 
-  const size_t num_kernels = nbatch * channels * output_height * output_width;
   const int num_threads = std::min(
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -397,6 +486,12 @@
     return;
   }
 
+#ifdef USE_ROCM
+  constexpr bool use_input = true;
+#else
+  constexpr bool use_input = false;
+#endif
+
   AT_DISPATCH_FLOATING_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16,
       grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@@ -414,6 +509,8 @@
       const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
           input_width, output_width, align_corners, scales_w);
 
+      const size_t num_kernels = nbatch * channels * output_height * output_width;
+
       upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
           <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
               input_height,
@@ -444,6 +541,8 @@
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);
 
+      const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
+
      upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
              num_threads,
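Moving num_kernels into each dispatch branch follows from the kernel rewrite: the NHWC launch still sizes its grid by output elements, while the generic launch sizes it by use_input — one thread per input element on ROCm, matching the gather loop over i_numel, and one per output element elsewhere.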

cmake/Modules/FindOpenBLAS.cmake

Lines changed: 7 additions & 2 deletions

@@ -29,10 +29,15 @@ SET(Open_BLAS_LIB_SEARCH_PATHS
   $ENV{OpenBLAS}/lib
   $ENV{OpenBLAS_HOME}
   $ENV{OpenBLAS_HOME}/lib
- )
+)
+
+SET(Open_BLAS_LIB_NAME openblas)
+IF(DEFINED ENV{OpenBLAS_LIB_NAME})
+  SET(Open_BLAS_LIB_NAME $ENV{OpenBLAS_LIB_NAME})
+ENDIF()
 
 FIND_PATH(OpenBLAS_INCLUDE_DIR NAMES cblas.h PATHS ${Open_BLAS_INCLUDE_SEARCH_PATHS})
-FIND_LIBRARY(OpenBLAS_LIB NAMES openblas PATHS ${Open_BLAS_LIB_SEARCH_PATHS})
+FIND_LIBRARY(OpenBLAS_LIB NAMES ${Open_BLAS_LIB_NAME} PATHS ${Open_BLAS_LIB_SEARCH_PATHS})
 
 SET(OpenBLAS_FOUND ON)
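With this change the OpenBLAS library base name is overridable from the environment; for example (hypothetical invocation), configuring with OpenBLAS_LIB_NAME=openblas64 set would make FIND_LIBRARY look for libopenblas64 instead of libopenblas, as is common for ILP64 builds installed under a suffixed name. With the variable unset, behavior is unchanged.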
