Commit 9948289
[ROCm] Improve reduction sum performance
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128
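As a rough illustration of which calls the bullet above is about, the following pure-Python sketch checks whether a reduction runs over the fastest-striding dimension of a contiguous tensor and whether that dimension is at least 128 elements long. It assumes that `dim0` denotes the extent of the reduced dimension, which is consistent with the reproducer's shape but is not stated in the commit message. The helper names `contiguous_strides` and `may_vectorize_input` are hypothetical, and the actual kernel heuristic may check further conditions that are not modeled here.

```python
# Hypothetical sketch; not the kernel's actual configuration logic.
MIN_REDUCE_EXTENT = 128  # threshold named in the commit message


def contiguous_strides(shape):
    """Return row-major (C-contiguous) element strides for a shape."""
    strides = []
    running = 1
    for extent in reversed(shape):
        strides.append(running)
        running *= extent
    return tuple(reversed(strides))


def may_vectorize_input(shape, reduce_dim, threshold=MIN_REDUCE_EXTENT):
    """True if the reduced dim is the stride-1 dim and is long enough.

    Assumption: "dim0" in the commit message is the extent of the
    reduced dimension when that dimension is the fastest-striding one.
    """
    strides = contiguous_strides(shape)
    reduces_fastest_dim = strides[reduce_dim] == 1
    return reduces_fastest_dim and shape[reduce_dim] >= threshold


if __name__ == "__main__":
    # The reproducer's case: sum over dim 1 of a (5079670, 128) tensor.
    print(may_vectorize_input((5079670, 128), 1))
    # Reducing over dim 0 is not a fastest-striding reduction here.
    print(may_vectorize_input((5079670, 128), 0))
```

Under this reading, the reproducer sits exactly on the boundary: its reduced dimension has 128 elements, so it qualifies only because the comparison is inclusive.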
**Reproducer:**
```
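import time
import torch

# Input shapes to benchmark and the dimension to reduce for each shape.
shapes = [
    (5079670, 128)
]
dims = [
    1
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)

    # Warm-up iterations so one-time setup costs are not measured.
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    # Timed iterations; synchronize before reading the clock because
    # kernel launches are asynchronous.
    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()

    mean_time = (end_time - start_time) / 50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")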
```
**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us
**After (MI300X):**
Avg time for shape (5079670, 128): 1008.59 us
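To put the two timings in perspective, the short calculation below derives the speedup and the implied input read rate from the figures reported above. It is a back-of-the-envelope sketch: it assumes 4 bytes per element (`torch.float`) and counts only the bytes read from the input, ignoring the much smaller output write. The function names are illustrative only.

```python
# Figures taken from the Before/After measurements in this commit message.
SHAPE = (5079670, 128)
BYTES_PER_ELEMENT = 4  # torch.float is 32-bit
BEFORE_US = 1629.99
AFTER_US = 1008.59


def speedup(before_us, after_us):
    """Ratio of the old average time to the new average time."""
    return before_us / after_us


def read_rate_tb_per_s(shape, time_us, bytes_per_element=BYTES_PER_ELEMENT):
    """Input bytes read per second, in TB/s (1 TB = 1e12 bytes)."""
    num_bytes = shape[0] * shape[1] * bytes_per_element
    return num_bytes / (time_us * 1e-6) / 1e12


if __name__ == "__main__":
    print(f"speedup: {speedup(BEFORE_US, AFTER_US):.2f}x")
    print(f"time saved: {(1 - AFTER_US / BEFORE_US) * 100:.1f}%")
    print(f"read rate before: {read_rate_tb_per_s(SHAPE, BEFORE_US):.2f} TB/s")
    print(f"read rate after:  {read_rate_tb_per_s(SHAPE, AFTER_US):.2f} TB/s")
```

The ratio comes out to roughly 1.6x, which corresponds to a reduction of about 38% in average time per call.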
cherry-pick of pytorch#160466
Fixes SWDEV-5461361

Parent: e1632fc
1 file changed: +1, -1 (line 1061 replaced)