Conversation

@matthiasdiener (Contributor) commented Dec 4, 2025

Description

Introduce a HIP implementation of the GEMM reference computation to speed up the reference computations.

Partly addresses https://github.com/ROCm/frameworks-internal/issues/14746

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Introduce a HIP implementation of the GEMM reference computation

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@matthiasdiener force-pushed the compute-ref-offload branch 7 times, most recently from 44df11e to e60b912 on December 9, 2025 20:53
@matthiasdiener changed the title from "[WIP] GEMM reference compute offload" to "GEMM reference computation offload" on Dec 9, 2025
@matthiasdiener self-assigned this on Dec 9, 2025
@matthiasdiener requested review from alextmagro and Copilot and removed the review request for alextmagro on December 10, 2025 00:19
@matthiasdiener marked this pull request as ready for review on December 10, 2025 00:19
Copilot AI left a comment

Pull request overview

This PR introduces a GPU-accelerated implementation of the GEMM reference computation using HIP/CUDA kernels to improve performance over the previous CPU-based implementation. The reference computation is critical for validating GEMM operations, and offloading it to the GPU significantly speeds up testing.

Key Changes

  • Replaced the CPU OpenMP-based reference GEMM computation with a GPU kernel implementation
  • Introduced compute_ref_kernel to perform matrix multiplication, bias addition, GELU activation, and scaling on the GPU (a rough sketch follows this list)
  • Refactored both the tensor-wise and MXFP8 code paths to use a common compute_ref_impl function that manages device memory and kernel execution

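For orientation, a minimal hypothetical sketch of what such a kernel could look like is shown below. The PR's actual kernel appears to use a blockwise reduction over k (see the review comment further down), and all identifiers, layouts, and the single post-scale here are illustrative assumptions rather than the code of this PR:

// Hypothetical sketch only: one thread per element of the m x n output,
// accumulating over k in fp32, then applying bias, GELU, and a final scale.
// Row-major layouts and a bias of length m are assumptions for illustration.
template <typename AType, typename BType, typename DType>
__global__ void compute_ref_kernel_sketch(const AType* A, const BType* B, DType* D,
                                          const float* bias, float scale,
                                          size_t m, size_t n, size_t k, bool apply_gelu) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;  // output row
  const size_t j = blockIdx.y * blockDim.y + threadIdx.y;  // output column
  if (i >= m || j >= n) return;

  float acc = 0.0f;
  for (size_t p = 0; p < k; ++p) {
    acc += static_cast<float>(A[i * k + p]) * static_cast<float>(B[p * n + j]);
  }
  if (bias != nullptr) acc += bias[i];
  if (apply_gelu) {
    // tanh approximation of GELU
    acc = 0.5f * acc * (1.0f + tanhf(0.7978845608f * (acc + 0.044715f * acc * acc * acc)));
  }
  D[i * n + j] = static_cast<DType>(acc * scale);
}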

@alextmagro (Contributor) left a comment

Hi Matthias! Looks great, I just have a couple of performance questions.


float val = 0.0f;

if (in_range) {
Collaborator commented:

Do we need to continue the subsequent blockwise reduction if not in_range?
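
For context, a minimal sketch of the usual pattern, assuming a shared-memory blockwise reduction (BLOCK_SIZE and partial are illustrative names, not identifiers from this PR): out-of-range threads cannot return before a __syncthreads(), but they can skip the accumulation work and contribute the identity value 0.0f:

__shared__ float partial[BLOCK_SIZE];
float val = 0.0f;
if (in_range) {
  // ... accumulate this thread's contribution into val ...
}
partial[threadIdx.x] = val;  // out-of-range threads contribute the neutral 0.0f
__syncthreads();             // must be reached by every thread in the block
// Tree reduction (assumes blockDim.x is a power of two)
for (unsigned s = blockDim.x / 2; s > 0; s >>= 1) {
  if (threadIdx.x < s) partial[threadIdx.x] += partial[threadIdx.x + s];
  __syncthreads();
}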

const size_t lenBias = m;
const size_t lenGelu = m * n;

const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0;
Collaborator commented:

The mxfp8 scale factor has an alignment requirement (128x4 for rowwise and 4x128 for colwise), so its size is not simply ceil(lenA/32):

// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;

See the pytorch mxfp8_tensor.py for more details:

scale_inv = torch.zeros(
    round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
    round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
    dtype=torch.uint8,
    device=device,
)
# Allocate FP8 data transpose if needed
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
    columnwise_data = torch.empty_like(data)
    columnwise_scale_inv = torch.zeros(
        round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
        round_up_to_nearest_multiple(shape[-1], 128),
        dtype=torch.uint8,
        device=device,
    )
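
On the C++ side of this test, the allocation size would follow the same round-up rule. A rough sketch, assuming A has shape m x k with 1x32 scaling blocks (round_up and the variable names are illustrative, not existing helpers in the codebase):

constexpr size_t round_up(size_t x, size_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

// Rowwise scales: m rows padded to a multiple of 128, k/32 scale columns padded to a multiple of 4.
const size_t lenA_scale_rowwise = use_mxfp8 ? round_up(m, 128) * round_up(k / 32, 4) : 0;
// Colwise scales: m/32 scale rows padded to a multiple of 4, k columns padded to a multiple of 128.
const size_t lenA_scale_colwise = use_mxfp8 ? round_up(m / 32, 4) * round_up(k, 128) : 0;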

fp8e8m0* dB_scale = nullptr;

// Allocations and H2D transfers
NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type)));
Collaborator commented:

We can adapt the existing test tensor classes (Tensor::Tensor(const std::string& name, ...)) and their space allocation functions (e.g. Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);) defined in tests/cpp/test_common.cu instead of reinventing them.

In fact, we could change the API of the reference computation to take a const Tensor& directly, so we don't need to re-allocate the input and do one extra copy.
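
A hypothetical signature along these lines, using the Tensor wrapper from tests/cpp/test_common.cu (the function name and parameter list are illustrative only):

// The reference computation reads data (and, for MXFP8, the scale-inv buffers)
// directly from the test Tensor objects, avoiding the extra cudaMalloc and
// host-to-device copies inside the reference path.
void compute_ref(const Tensor& A, const Tensor& B, Tensor& D,
                 const Tensor* bias, bool use_gelu, bool use_mxfp8);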
