Conversation


@jerrymannil jerrymannil commented Aug 22, 2025

cherry-pick of pytorch#160979
Less-performant fix until pytorch#161180 is finalized

  • The global reduction path in the reduction kernel currently performs two threadfence operations.
  • The first threadfence is executed by all threads in all blocks, whereas the second is run only by threads in a single block.
  • On AMD GPUs, threadfence is a heavyweight operation, especially when run by all threads in the system (due to cross-XCD synchronization).
  • Using fine-grained fences therefore gives a significant performance boost on AMD GPUs.
  • We issue a release fence when threads write to the reduce buffer in global memory, and an acquire fence when threads read from the reduce buffer (see the sketch after this list).
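
A minimal HIP-style sketch of the pattern described above, for illustration only: it is not the actual code in PyTorch's Reduce.cuh, the kernel/buffer/semaphore names are made up, and the real fix may use different fence intrinsics and scopes. It shows each block publishing its partial result to the reduce buffer behind a release fence, and the last block to arrive reading the buffer behind an acquire fence, instead of two device-wide `__threadfence()` calls.

```cpp
// Hypothetical sketch, not the actual Reduce.cuh implementation.
// reduce_buf has one float per block; semaphore is a single int, zero-initialized.
#include <hip/hip_runtime.h>

__global__ void global_sum(const float* in, int n,
                           float* reduce_buf, int* semaphore, float* out) {
    // Per-block partial sum via a shared-memory accumulator.
    __shared__ float block_sum;
    if (threadIdx.x == 0) block_sum = 0.f;
    __syncthreads();

    float v = 0.f;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += gridDim.x * blockDim.x)
        v += in[i];
    atomicAdd(&block_sum, v);
    __syncthreads();

    if (threadIdx.x == 0) {
        reduce_buf[blockIdx.x] = block_sum;          // publish this block's partial result
        __atomic_thread_fence(__ATOMIC_RELEASE);     // release fence instead of __threadfence()
        int prev = atomicAdd(semaphore, 1);
        if (prev == (int)gridDim.x - 1) {            // last block to arrive does the final sum
            __atomic_thread_fence(__ATOMIC_ACQUIRE); // acquire fence instead of __threadfence()
            float total = 0.f;
            for (unsigned b = 0; b < gridDim.x; ++b)
                total += reduce_buf[b];
            *out = total;
        }
    }
}
```

The release fence orders each block's write to `reduce_buf` before its increment of `semaphore`, and the acquire fence in the last block orders the semaphore read before the reads of `reduce_buf`, so the heavier device-wide `__threadfence()` is not needed on this path.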

Co-authors: @amd-hhashemi, @jeffdaily

Reproducer:

```python
import torch

shapes = [(2, 896, 59, 91)]
dims = [(2, 3)]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.bfloat16)
    x = x.to(memory_format=torch.channels_last)

    # Warm-up iterations
    for _ in range(20):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    torch.cuda.synchronize()

    # Timed iterations
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(100):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    end_evt.record()
    torch.cuda.synchronize()

    # elapsed_time() returns milliseconds; report microseconds per iteration
    print(f"Avg time for shape {shape}: {start_evt.elapsed_time(end_evt) / 100 * 1e3:.2f} us")
```

Fixes SWDEV-545710

Cherry-picked to release/2.8 branch via #2561

Cherry-picked to rocm7.1_internal_testing branch via #2563

@jerrymannil jerrymannil self-assigned this Aug 22, 2025
@jerrymannil (Collaborator, Author) commented:

Results (MI300X):

Before:
Avg time for shape (2, 896, 59, 91): 82.13 us

After:
Avg time for shape (2, 896, 59, 91): 61 us

@rocm-repo-management-api

Jenkins build for commit baddc98b5389ba858f9677a5a2738914e429192d is in progress

@pruthvistony pruthvistony merged commit c00d48c into release/2.7 Aug 22, 2025
0 of 2 checks passed
@pruthvistony pruthvistony deleted the jerrymannil-patch-1 branch August 22, 2025 16:55
@jerrymannil (Collaborator, Author) commented:

! cherry-pick --onto release/2.8 rocm7.1_internal_testing

dhonnappa-amd pushed a commit that referenced this pull request Aug 22, 2025

dhonnappa-amd pushed a commit that referenced this pull request Aug 22, 2025

dhonnappa-amd pushed a commit that referenced this pull request Aug 22, 2025

dhonnappa-amd pushed a commit that referenced this pull request Aug 22, 2025

@dhonnappa-amd

Created branch autogenerated/release/2.8_cherry-pick_pr-2553 and #2561

Created branch autogenerated/rocm7.1_internal_testing_cherry-pick_pr-2553 and #2563

Comment processed by Build

jerrymannil added a commit that referenced this pull request Aug 22, 2025
jerrymannil added a commit that referenced this pull request Aug 22, 2025
…e in reduction (#2563)

Cherry-pick of #2553

Co-authored-by: Jerry Mannil <[email protected]>
jerrymannil added a commit that referenced this pull request Sep 5, 2025
…e in reduction (#2563)

Cherry-pick of #2553

Co-authored-by: Jerry Mannil <[email protected]>