Skip to content

Commit 79e71b1

Browse files
committed
Merge branch 'release/2.9' into vmijovic/add_support_hipblaslt_gfx1151
2 parents 151b746 + 25a49ce commit 79e71b1

File tree

2 files changed

+40
-4
lines changed

2 files changed

+40
-4
lines changed

aten/src/ATen/native/cuda/Normalization.cuh

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace at::native {
2323

2424
// The maximum number of threads in a block
2525
#if defined(USE_ROCM)
26-
constexpr int MAX_BLOCK_SIZE = 256;
26+
constexpr int MAX_BLOCK_SIZE = 1024;
2727
#else
2828
constexpr int MAX_BLOCK_SIZE = 512;
2929
#endif
@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
3333
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
3434
static int getNumThreads(int nElem) {
3535
#if defined(USE_ROCM)
36-
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
36+
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
3737
#else
3838
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
3939
#endif
@@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
115115
// first the reductions each thread does separately
116116
scalar_t sum = static_cast<scalar_t>(0);
117117
for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
118+
#if defined(USE_ROCM)
119+
constexpr int UNRL = 4; // load de-serialization (unroll) factor
120+
scalar_t tmp[UNRL];
121+
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
122+
#pragma unroll
123+
for (int u = 0; u < UNRL; u++)
124+
tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
125+
#pragma unroll
126+
for (int u = 0; u < UNRL; u++)
127+
if (x+u*blockDim.x < tensor.size(2))
128+
sum += tmp[u];
129+
}
130+
#else
118131
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
119132
sum += op(batch, plane, x);
120133
}
134+
#endif
121135
}
122136
__shared__ scalar_t shared[C10_WARP_SIZE];
123137
SumReduceOp<scalar_t> reduce_op;
@@ -292,13 +306,30 @@ __global__ void batch_norm_collect_statistics_kernel(
292306
stat_accscalar_t var_n = 0;
293307
int n = 0;
294308
for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
309+
#if defined(USE_ROCM)
310+
constexpr int UNRL = 4;
311+
stat_accscalar_t v_[UNRL];
312+
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
313+
for (int u = 0; u < UNRL; u++)
314+
v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
315+
for (int u = 0; u < UNRL; u++) {
316+
if (x+u*blockDim.x < input.size(2)) {
317+
stat_accscalar_t d1 = v_[u] - avg;
318+
n++;
319+
avg += d1 / n;
320+
var_n += d1 * (v_[u] - avg);
321+
}
322+
}
323+
}
324+
#else
295325
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
296326
stat_accscalar_t v = input[batch][plane][x];
297327
stat_accscalar_t d1 = v - avg;
298328
n++;
299329
avg += d1 / n;
300330
var_n += d1 * (v - avg);
301331
}
332+
#endif
302333
}
303334

304335
// first warpSum to get one value per thread to

cmake/Modules/FindOpenBLAS.cmake

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,15 @@ SET(Open_BLAS_LIB_SEARCH_PATHS
2929
$ENV{OpenBLAS}/lib
3030
$ENV{OpenBLAS_HOME}
3131
$ENV{OpenBLAS_HOME}/lib
32-
)
32+
)
33+
34+
SET(Open_BLAS_LIB_NAME openblas)
35+
IF(DEFINED ENV{OpenBLAS_LIB_NAME})
36+
SET(Open_BLAS_LIB_NAME $ENV{OpenBLAS_LIB_NAME})
37+
ENDIF()
3338

3439
FIND_PATH(OpenBLAS_INCLUDE_DIR NAMES cblas.h PATHS ${Open_BLAS_INCLUDE_SEARCH_PATHS})
35-
FIND_LIBRARY(OpenBLAS_LIB NAMES openblas PATHS ${Open_BLAS_LIB_SEARCH_PATHS})
40+
FIND_LIBRARY(OpenBLAS_LIB NAMES ${Open_BLAS_LIB_NAME} PATHS ${Open_BLAS_LIB_SEARCH_PATHS})
3641

3742
SET(OpenBLAS_FOUND ON)
3843

0 commit comments

Comments
 (0)