
Commit 9948289

[ROCm] Improve reduction sum performance
* Use input vectorization for `reduction_on_fastest_striding_dimension` when dim0 >= 128. Note that the reproducer below reduces over a dimension of exactly 128 elements, the boundary case that the relaxed `>=` condition now covers.

**Reproducer:**

```
import time

import torch

shapes = [(5079670, 128)]
dims = [1]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    # Warm-up iterations so the timed loop measures steady-state performance.
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()

    mean_time = (end_time - start_time) / 50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):** Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X):** Avg time for shape (5079670, 128): 1008.59 us

Cherry-pick of pytorch#160466

Fixes SWDEV-546136
Parent: e1632fc

File tree

1 file changed (+1, -1)

aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 1 addition & 1 deletion
```diff
@@ -1058,7 +1058,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
   // In such case, values in each loaded vector always correspond to different outputs.
   if (fastest_moving_stride == sizeof(scalar_t)) {
 #ifdef USE_ROCM
-    if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1) {
+    if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
 #else
     if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
 #endif
```
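
For readers unfamiliar with the term, "input vectorization" here means loading several contiguous input elements per memory transaction when the reduction runs along the fastest-moving dimension. The sketch below is a minimal, self-contained illustration of the idea, not the actual Reduce.cuh kernel: the kernel name `row_sum_vec4`, the block size of 128, and the test shape are all assumptions chosen for clarity, and error checking is omitted for brevity.

```cuda
// Minimal sketch of a vectorized-input row reduction (hypothetical kernel,
// NOT PyTorch's Reduce.cuh implementation). Builds with nvcc; the same
// pattern is hipify-able for ROCm.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// One thread block sums one row of a [num_rows, row_len] float matrix.
// row_len is assumed to be a multiple of 4 so every load can be a float4.
__global__ void row_sum_vec4(const float* __restrict__ in,
                             float* __restrict__ out,
                             int row_len) {
  const float4* row =
      reinterpret_cast<const float4*>(in + (size_t)blockIdx.x * row_len);
  const int num_vec = row_len / 4;

  // Vectorized loads: each iteration pulls 4 contiguous floats in a single
  // 16-byte transaction -- the essence of "input vectorization".
  float acc = 0.f;
  for (int i = threadIdx.x; i < num_vec; i += blockDim.x) {
    float4 v = row[i];
    acc += v.x + v.y + v.z + v.w;
  }

  // Block-wide tree reduction in shared memory (blockDim.x must be a
  // power of two and <= 128 here).
  __shared__ float smem[128];
  smem[threadIdx.x] = acc;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

int main() {
  const int rows = 1024, cols = 128;  // cols = reduced (fastest-moving) dim
  std::vector<float> h_in((size_t)rows * cols, 1.0f);
  float *d_in, *d_out;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, rows * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  row_sum_vec4<<<rows, 128>>>(d_in, d_out, cols);

  float first = 0.f;
  cudaMemcpy(&first, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("row 0 sum = %.1f (expect %d)\n", first, cols);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

The same access pattern applies on ROCm wavefronts; the one-character change in this commit simply lets PyTorch take the vectorized path when the reduced dimension is exactly 128 rather than requiring it to be strictly larger.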
