Fix convolution fp16 performance drop on gfx12xx (#403)
* Remove hardcoded convolution NCHW layout assignment for fp16 precision.
* PR openxla#32773: [ROCm] Fix convolution fp16 performance drop on gfx11xx, gfx12xx
Imported from GitHub PR openxla#32773
📝 Summary of Changes
Remove hardcoded NHWC convolution layout for fp16 precision.
🎯 Justification
Performance drops for fp16 precision on gfx11xx and gfx12xx GPUs were observed internally, as well as by the [community](jax-ml/jax#30548).
🚀 Kind of Contribution
🐛 Bug Fix
📊 Benchmark
A community member provided a script with which the [profiling can be done](jax-ml/jax#30548 (comment)).
Results of the profiling after the fix show a significant performance improvement for fp16 on gfx12xx:
```
Running on: rocm:0
Testing float32...
Avg time: 0.092307 s, Throughput: 1.68 TFLOP/s
Testing float16...
Avg time: 0.011742 s, Throughput: 13.17 TFLOP/s
Testing bfloat16...
Avg time: 0.011989 s, Throughput: 12.90 TFLOP/s
```
Results of the profiling before the fix:
```
Running on: rocm:0
Testing float32...
Avg time: 0.092312 s, Throughput: 1.67 TFLOP/s
Testing float16...
Avg time: 0.775142 s, Throughput: 0.20 TFLOP/s
Testing bfloat16...
Avg time: 0.011990 s, Throughput: 12.90 TFLOP/s
```
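For reference, a benchmark of this shape can be sketched in JAX. This is a minimal, hypothetical sketch, not the linked community script: the tensor shapes, iteration count, and FLOP accounting here are illustrative assumptions.

```python
import time
import jax
import jax.numpy as jnp

def bench_conv(dtype, n_iters=10):
    # NCHW activations, OIHW filters — the default layouts for jax.lax.conv.
    # Shapes are illustrative, not those of the community script.
    x = jnp.ones((4, 32, 64, 64), dtype=dtype)
    k = jnp.ones((32, 32, 3, 3), dtype=dtype)
    conv = jax.jit(lambda x, k: jax.lax.conv(x, k, (1, 1), "SAME"))
    conv(x, k).block_until_ready()  # compile + warm up outside the timed loop
    start = time.perf_counter()
    for _ in range(n_iters):
        conv(x, k).block_until_ready()
    avg = (time.perf_counter() - start) / n_iters
    # FLOPs per conv: 2 * N * Cout * Cin * Hout * Wout * Kh * Kw
    flops = 2 * 4 * 32 * 32 * 64 * 64 * 3 * 3
    print(f"Avg time: {avg:.6f} s, Throughput: {flops / avg / 1e12:.2f} TFLOP/s")
    return avg

print("Running on:", jax.devices()[0])
for name, dt in (("float32", jnp.float32),
                 ("float16", jnp.float16),
                 ("bfloat16", jnp.bfloat16)):
    print(f"Testing {name}...")
    bench_conv(dt)
```

On a gfx12xx device, a large fp16/fp32 throughput gap from a script like this is the symptom this PR addresses.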
@xla-rotation can you please review this PR?
Copybara import of the project:
--
c9fdba7 by Aleksa Arsic <[email protected]>:
Remove hardcoded convolution NCHW layout assignment for fp16 precision.
--
69660d1 by Aleksa Arsic <[email protected]>:
Add unit tests for ROCm layout assignment.
Merging this change closes openxla#32773
COPYBARA_INTEGRATE_REVIEW=openxla#32773 from ROCm:ci_fix-hardcoded-NHWC-conv-layout-for-fp16 69660d1
PiperOrigin-RevId: 822022522