@@ -23,7 +23,7 @@ namespace at::native {
 
 // The maximum number of threads in a block
 #if defined(USE_ROCM)
-constexpr int MAX_BLOCK_SIZE = 256;
+constexpr int MAX_BLOCK_SIZE = 1024;
 #else
 constexpr int MAX_BLOCK_SIZE = 512;
 #endif
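The ROCm cap rises from 256 to 1024 threads, the hardware maximum block size on AMD GPUs, so large per-channel reductions get more parallelism per block. One invariant worth checking: reduce() below keeps one partial sum per warp in __shared__ scalar_t shared[C10_WARP_SIZE], so a block may hold at most C10_WARP_SIZE warps. Assuming C10_WARP_SIZE is 64 on ROCm, 1024 threads is 16 wavefronts, well inside the limit; a sketch of a guard one could add (not part of this commit):

    // Sketch only: document the shared-memory invariant of the two-level
    // reduction. On ROCm, 1024 / 64 = 16 wavefronts <= 64 slots.
    static_assert(MAX_BLOCK_SIZE / C10_WARP_SIZE <= C10_WARP_SIZE,
                  "reduce() stores one partial per warp in shared[C10_WARP_SIZE]");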
@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
 // Number of threads in a block given an input size up to MAX_BLOCK_SIZE
 static int getNumThreads(int nElem) {
 #if defined(USE_ROCM)
-  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
+  int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
 #else
   int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
 #endif
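With the new buckets, a per-plane reduction on ROCm never launches with fewer than 64 threads, i.e. at least one full wavefront, where the old 16- and 32-thread buckets left wavefronts partially idle. For context, a sketch of the first-fit selection that consumes threadSizes in the rest of this function, paraphrased rather than quoted from the file (getNumThreadsSketch is an illustrative name):

    // Return the smallest bucket that covers nElem, else the cap.
    static int getNumThreadsSketch(int nElem) {
      const int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
      for (int i = 0; i != 5; ++i) {
        if (nElem <= threadSizes[i]) return threadSizes[i];
      }
      return MAX_BLOCK_SIZE;
    }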
@@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
   // first the reductions each thread does separately
   scalar_t sum = static_cast<scalar_t>(0);
   for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
+#if defined(USE_ROCM)
+    constexpr int UNRL = 4; // load de-serialization (unroll) factor
+    scalar_t tmp[UNRL];
+    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
+#pragma unroll
+      for (int u = 0; u < UNRL; u++)
+        tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x))); // clamped load
+#pragma unroll
+      for (int u = 0; u < UNRL; u++)
+        if (x+u*blockDim.x < tensor.size(2)) // mask out the clamped tail
+          sum += tmp[u];
+    }
+#else
     for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
       sum += op(batch, plane, x);
     }
+#endif
   }
   __shared__ scalar_t shared[C10_WARP_SIZE];
   SumReduceOp<scalar_t> reduce_op;
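The new ROCm path manually unrolls the strided read four ways: each thread issues UNRL independent loads before consuming any of them, so their memory latencies overlap instead of serializing. Out-of-range lanes are clamped to the last valid index on the load and masked off on the accumulate, which keeps the sum identical to the plain loop. A self-contained sketch of the idiom under those assumptions (strided_sum is an illustrative name, not part of the patch):

    // Each thread sums data[threadIdx.x], data[threadIdx.x + blockDim.x], ...
    // with UNRL loads in flight per outer iteration.
    template <typename T, int UNRL = 4>
    __device__ T strided_sum(const T* data, int n) {
      T sum = T(0);
      for (int x = threadIdx.x; x < n; x += blockDim.x * UNRL) {
        T tmp[UNRL];
    #pragma unroll
        for (int u = 0; u < UNRL; u++)
          tmp[u] = data[min(n - 1, (int)(x + u * blockDim.x))]; // clamped, always in range
    #pragma unroll
        for (int u = 0; u < UNRL; u++)
          if (x + u * blockDim.x < n) // drop the clamped duplicates
            sum += tmp[u];
      }
      return sum;
    }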
@@ -292,13 +306,30 @@ __global__ void batch_norm_collect_statistics_kernel(
   stat_accscalar_t var_n = 0;
   int n = 0;
   for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
+#if defined(USE_ROCM)
+    constexpr int UNRL = 4; // same unroll factor as in reduce()
+    stat_accscalar_t v_[UNRL];
+    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
+      for (int u = 0; u < UNRL; u++)
+        v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)]; // clamped load
+      for (int u = 0; u < UNRL; u++) {
+        if (x+u*blockDim.x < input.size(2)) { // skip clamped duplicates
+          stat_accscalar_t d1 = v_[u] - avg;
+          n++;
+          avg += d1 / n;
+          var_n += d1 * (v_[u] - avg);
+        }
+      }
+    }
+#else
     for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
       stat_accscalar_t v = input[batch][plane][x];
       stat_accscalar_t d1 = v - avg;
       n++;
       avg += d1 / n;
       var_n += d1 * (v - avg);
     }
+#endif
   }
 
   // first warpSum to get one value per thread to
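This hunk applies the same four-way load batching to Welford's online mean/variance update. Only the loads are grouped; the u-loop still applies the updates in the original index order (x, x+blockDim.x, ...), and the bounds guard skips the clamped duplicates, so n, avg, and var_n evolve exactly as in the scalar loop. For reference, a minimal host-side sketch of the recurrence being preserved (welford_sketch is an illustrative name): after k accepted values, avg holds their mean and var_n holds k times their variance.

    #include <utility>

    // Welford's streaming mean/variance: numerically stabler than the
    // naive sum / sum-of-squares formulation.
    template <typename T>
    std::pair<T, T> welford_sketch(const T* v, int n_elems) {
      T avg = 0, var_n = 0;
      int n = 0;
      for (int i = 0; i < n_elems; ++i) {
        T d1 = v[i] - avg;            // delta against the old mean
        n++;
        avg += d1 / n;                // fold v[i] into the running mean
        var_n += d1 * (v[i] - avg);   // accumulates n * variance
      }
      return { avg, n ? var_n / n : T(0) }; // mean and (biased) variance
    }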