GEMM reference computation offload #392
Base: dev
Conversation
Force-pushed from 44df11e to e60b912
Force-pushed from e60b912 to 557d580
Force-pushed from 557d580 to ad748da
Pull request overview
This PR introduces a GPU-accelerated implementation of the GEMM reference computation using HIP/CUDA kernels to improve performance over the previous CPU-based implementation. The reference computation is critical for validating GEMM operations, and offloading it to the GPU significantly speeds up testing.
Key Changes
- Replaced CPU OpenMP-based reference GEMM computation with a GPU kernel implementation
- Introduced compute_ref_kernel to perform matrix multiplication, bias addition, GELU activation, and scaling on the GPU
- Refactored both the tensor-wise and MXFP8 code paths to use a common compute_ref_impl function that manages device memory and kernel execution
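
For orientation, here is a minimal sketch of what such a reference kernel can look like: one thread per output element, fp32 accumulation, optional bias, a tanh-approximated GELU, and a final scale. This is not the PR's compute_ref_kernel; the signature, data layouts, template parameters and the exact GELU variant are assumptions for illustration only.

```cpp
// Illustrative sketch only, not the PR's compute_ref_kernel: a naive GPU
// reference GEMM D = scale * act(A * B + bias), one thread per output element.
#include <cuda_runtime.h>

template <typename AType, typename BType, typename DType>
__global__ void ref_gemm_sketch(const AType* A, const BType* B, DType* D,
                                const float* bias, float scale,
                                size_t m, size_t n, size_t k,
                                bool add_bias, bool apply_gelu) {
  const size_t row = blockIdx.y * blockDim.y + threadIdx.y;  // 0..m-1
  const size_t col = blockIdx.x * blockDim.x + threadIdx.x;  // 0..n-1
  if (row >= m || col >= n) return;  // this naive version has no blockwise reduction

  // Accumulate the dot product in fp32, as a reference computation normally does.
  float val = 0.0f;
  for (size_t i = 0; i < k; ++i) {
    val += static_cast<float>(A[row * k + i]) * static_cast<float>(B[i * n + col]);
  }
  if (add_bias) {
    val += bias[row];  // bias of length m, matching lenBias = m in the diff below
  }
  if (apply_gelu) {
    // tanh approximation of GELU; the PR may use a different variant
    val = 0.5f * val * (1.0f + tanhf(0.7978845608f * (val + 0.044715f * val * val * val)));
  }
  D[row * n + col] = static_cast<DType>(scale * val);
}
```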
alextmagro left a comment:
Hi Matthias! Looks great, just have a couple performance questions
Force-pushed from dbf7ae9 to 11e090b
```cpp
float val = 0.0f;
...
if (in_range) {
```
Do we need to continue the subsequent blockwise reduction if not in_range?
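
For context: in the canonical shared-memory blockwise reduction pattern, out-of-range threads cannot simply return early, because then only part of the block would reach the __syncthreads() barriers, which is undefined behavior; the usual approach keeps them in the reduction contributing a neutral 0. A minimal sketch under that assumption (TILE, block_reduce_sketch, and the data layout are illustrative, not the PR's code):

```cpp
// Sketch of the canonical pattern: out-of-range threads contribute a neutral
// 0.0f but stay in the reduction so that every thread reaches __syncthreads().
// TILE, block_reduce_sketch and the data layout are illustrative assumptions.
constexpr int TILE = 256;

__global__ void block_reduce_sketch(const float* partials, float* out, size_t len) {
  __shared__ float sdata[TILE];
  const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const bool in_range = idx < len;

  float val = 0.0f;
  if (in_range) {
    val = partials[idx];  // in a GEMM kernel: this thread's partial dot product
  }
  sdata[threadIdx.x] = val;  // out-of-range threads write the neutral element
  __syncthreads();

  // Tree reduction. Returning early for !in_range threads would leave some
  // threads missing these barriers (undefined behavior), so everyone loops.
  for (unsigned s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) sdata[threadIdx.x] += sdata[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) out[blockIdx.x] = sdata[0];
}
```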
```cpp
const size_t lenBias = m;
const size_t lenGelu = m * n;
...
const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0;
```
The MXFP8 scale factor has an alignment requirement (128x4 for rowwise and 4x128 for colwise), so it cannot be sized simply as ceil(lenA/32):
TransformerEngine/transformer_engine/common/common.h, lines 689 to 693 (at 669b556):
```cpp
// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
```
See the pytorch mxfp8_tensor.py for more details:
TransformerEngine/transformer_engine/pytorch/tensor/mxfp8_tensor.py, lines 113 to 130 (at 669b556):
```python
scale_inv = torch.zeros(
    round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
    round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
    dtype=torch.uint8,
    device=device,
)
# Allocate FP8 data transpose if needed
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
    columnwise_data = torch.empty_like(data)
    columnwise_scale_inv = torch.zeros(
        round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
        round_up_to_nearest_multiple(shape[-1], 128),
        dtype=torch.uint8,
        device=device,
    )
```
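
To make the sizing concrete, here is a small sketch of how a padded scale-buffer length could be computed from those alignment constants. It mirrors the PyTorch allocation above but is written for the C++ test; round_up, rowwise_scale_len, and colwise_scale_len are illustrative names, not existing helpers.

```cpp
// Sketch only: scale buffers padded to the [128,4] rowwise / [4,128] colwise
// alignment from common.h instead of ceil(lenA / 32). Helper names are assumptions.
#include <cstddef>

constexpr size_t kMxfp8BlockSize = 32;

constexpr size_t round_up(size_t x, size_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

// Rowwise: [round_up(rows, 128), round_up(cols / 32, 4)] scale elements.
inline size_t rowwise_scale_len(size_t rows, size_t cols) {
  return round_up(rows, 128) * round_up(cols / kMxfp8BlockSize, 4);
}

// Colwise: [round_up(rows / 32, 4), round_up(cols, 128)] scale elements.
inline size_t colwise_scale_len(size_t rows, size_t cols) {
  return round_up(rows / kMxfp8BlockSize, 4) * round_up(cols, 128);
}
```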
```cpp
fp8e8m0* dB_scale = nullptr;
...
// Allocations and H2D transfers
NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type)));
```
We can adapt the existing test tensor classes (TransformerEngine/tests/cpp/test_common.cu, line 226 at 669b556):
```cpp
Tensor::Tensor(const std::string& name,
```
for example:
```cpp
Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
```
In fact, we can change the API of the reference computation to take a const Tensor& directly, so we don't need to re-allocate the input and do one extra copy.
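
Purely as an illustration of that suggestion, one possible shape for the interface (the real parameter list of compute_ref_impl is not visible in this excerpt, so everything beyond "take const Tensor& directly" is an assumption):

```cpp
// Hypothetical signature only. Passing the test-suite Tensor wrappers, which
// already own device-side buffers and scale tensors, would let the reference
// computation read them directly instead of cudaMalloc-ing and copying raw
// host arrays inside compute_ref_impl.
void compute_ref_impl(const Tensor& A, const Tensor& B, Tensor& D,
                      const Tensor& bias, Tensor& pre_gelu_out,
                      bool use_bias, bool use_gelu, bool use_mxfp8);
```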
Description
Introduce a HIP implementation of the GEMM reference computation to speed up the reference checks.
Partly addresses https://github.com/ROCm/frameworks-internal/issues/14746
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: