333 changes: 269 additions & 64 deletions tests/cpp/operator/test_cublaslt_gemm.cu
@@ -51,11 +51,248 @@ using TShape = std::vector<size_t>;
} // namespace


float ref_gelu(float x){
__device__ __host__ __forceinline__ float ref_gelu(float x){
float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
return x * cdf;
}

template <typename A_Type, typename B_Type, typename Bias_Type,
typename Gelu_Type, typename D_Type>
__global__ void compute_ref_kernel(
const A_Type* __restrict__ a_data,
const B_Type* __restrict__ b_data,
float a_scale_inv_scalar, // used when mxfp8 == false
float b_scale_inv_scalar,
const fp8e8m0* __restrict__ a_scale_inv_mxfp8, // used when mxfp8 == true
const fp8e8m0* __restrict__ b_scale_inv_mxfp8,
const Bias_Type* __restrict__ bias_data,
float d_scale,
size_t m, size_t k, size_t n,
D_Type* __restrict__ d_data,
float* __restrict__ d_amax,
Gelu_Type* __restrict__ gelu_data,
bool transa,
bool transb,
bool is_fp8_output)
{
const size_t jj = blockIdx.x * blockDim.x + threadIdx.x;
const size_t ii = blockIdx.y * blockDim.y + threadIdx.y;

const bool in_range = (ii < m) && (jj < n);

float val = 0.0f;

if (in_range) {
for (size_t kk = 0; kk < k; ++kk) {
const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii);
const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk);

float a_scale_inv_val = a_scale_inv_scalar;
float b_scale_inv_val = b_scale_inv_scalar;

if (a_scale_inv_mxfp8) {
const size_t a_scale_idx =
transa ? (a_idx / 32) : ((kk / 32) * m + ii);
const size_t b_scale_idx =
transb ? ((kk / 32) * n + jj) : (b_idx / 32);

const float a_byte = static_cast<float>(a_scale_inv_mxfp8[a_scale_idx]);
const float b_byte = static_cast<float>(b_scale_inv_mxfp8[b_scale_idx]);

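// Decode the e8m0 scales: the stored byte is a biased exponent (bias 127),
// so the inverse scale is 2^(byte - 127).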
a_scale_inv_val = exp2f(a_byte - 127.0f);
b_scale_inv_val = exp2f(b_byte - 127.0f);
}

const float a_val = static_cast<float>(a_data[a_idx]);
const float b_val = static_cast<float>(b_data[b_idx]);

val += a_scale_inv_val * a_val * b_scale_inv_val * b_val;
}

if (bias_data) {
val += static_cast<float>(bias_data[ii]);
}

if (gelu_data) {
gelu_data[ii + jj * m] = static_cast<Gelu_Type>(val);
val = ref_gelu(val);
}

const float scaled = val * d_scale;
d_data[ii + jj * m] = static_cast<D_Type>(scaled);
}

// Blockwise reduction for amax
if (is_fp8_output && d_amax) {
const int tid = threadIdx.y * blockDim.x + threadIdx.x;
const int nthreads = blockDim.x * blockDim.y;

extern __shared__ float s_amax[];

// Out-of-range threads contribute 0
s_amax[tid] = in_range ? fabsf(val) : 0.0f;
__syncthreads();

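// Tree reduction over the block; assumes blockDim.x * blockDim.y is a power of two (16x16 at the launch site).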
for (int offset = nthreads / 2; offset > 0; offset /= 2) {
if (tid < offset) {
s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]);
}
__syncthreads();
}

if (tid == 0) {
const float block_max = s_amax[0];
// CUDA's built-in atomicMax has no float overload. Since amax is non-negative,
// comparing the raw bit patterns as ints is equivalent to comparing the values.
atomicMax(reinterpret_cast<int*>(d_amax), __float_as_int(block_max));
}
}
}

// Common implementation used by both tensor-wise and MXFP8 frontends
template <typename A_Type, typename B_Type, typename Bias_Type,
typename Gelu_Type, typename D_Type>
static void compute_ref_impl(
const A_Type* a_data,
const B_Type* b_data,
float a_scale_inv_scalar, // used when mxfp8 == false
float b_scale_inv_scalar,
const fp8e8m0* a_scale_inv_mxfp8, // used when mxfp8 == true
const fp8e8m0* b_scale_inv_mxfp8,
const Bias_Type* bias_data,
float d_scale,
size_t m, size_t k, size_t n,
D_Type* d_data,
float* d_amax_host,
Gelu_Type* gelu_data,
bool transa,
bool transb)
{
using transformer_engine::DType;
using ::TypeInfo;
using ::isFp8Type;

const bool use_mxfp8 = (a_scale_inv_mxfp8 != nullptr);

const DType dtype = TypeInfo<D_Type>::dtype;
const bool is_fp8_output = isFp8Type(dtype);

const size_t lenA = m * k;
const size_t lenB = k * n;
const size_t lenD = m * n;
const size_t lenBias = m;
const size_t lenGelu = m * n;

const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0;
const size_t lenB_scale = use_mxfp8 ? (lenB + 31) / 32 : 0;

A_Type* dA = nullptr;
B_Type* dB = nullptr;
Bias_Type* dBias = nullptr;
D_Type* dD = nullptr;
Gelu_Type* dGelu = nullptr;
float* dAmax = nullptr;
fp8e8m0* dA_scale = nullptr;
fp8e8m0* dB_scale = nullptr;

// Allocations and H2D transfers
NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type)));
Collaborator:
We can adapt the existing test tensor classes (Tensor::Tensor(const std::string& name, ...)) and their space-allocation helpers (e.g. Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);) defined in tests/cpp/test_common.cu instead of reinventing them.
In fact, we can change the API of the reference computation to take a const Tensor& directly, so we don't need to re-allocate the inputs and do an extra copy.

Contributor Author:
What do you think of 3ecea7f? This also merges the mxfp8/non-mxfp8 paths.

Collaborator:
Thanks for consolidating with the existing APIs in test_common.cu.
I still see some cudaMalloc and cudaFree calls that can be replaced with the existing test tensor class APIs. For example, for the device pointer for the scales (NVTE_CHECK_CUDA(cudaMalloc(&d_a_scale_packed, a_scale_packed.size() * sizeof(fp8e8m0)));), the corresponding test tensor allocation can be found here:

if (rowwise) {
(void)cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*)
(void)cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
auto scale_dtype = rowwise_scale_meta.type;
tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
}
if (columnwise) {
(void)cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*)
(void)cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
auto scale_dtype = colwise_scale_meta.type;
tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
}

Contributor Author:
I replaced the remaining raw allocations in the reference path with test::Tensor for the temporary device buffers (RefD/RefGelu/RefAmax) in e11e400.

Collaborator:
I see, that does save some cudaMalloc/cudaFree calls.
How about we put the RefD instantiation inside performTest and pass the Tensor RefD (including its amax) and RefPreGeluOut to run_reference directly, instead of std::unique_ptr<D_Type[]>& ref_D, float* ref_amax_d, and std::unique_ptr<Gelu_Type[]>& ref_pre_gelu_out? That would also save some of the reference CPU pointer allocations.

Contributor Author:
What do you think of 325ece6?
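For illustration, a minimal sketch of the Tensor-based entry point discussed in this thread; the parameter list and the exact use of the test::Tensor class from tests/cpp/test_common.cu are assumptions, not the final API:

// Hypothetical signature only: passing test::Tensor directly lets the reference
// reuse the tensors' existing device buffers and amax storage instead of
// re-allocating and copying. All names below are placeholders.
template <typename A_Type, typename B_Type, typename Bias_Type,
          typename Gelu_Type, typename D_Type>
void run_reference(const test::Tensor& A, const test::Tensor& B,
                   const test::Tensor* bias, float d_scale,
                   size_t m, size_t k, size_t n,
                   test::Tensor& ref_D,             // D data plus amax for fp8 outputs
                   test::Tensor* ref_pre_gelu_out,  // optional pre-GELU output
                   bool transa, bool transb);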

NVTE_CHECK_CUDA(cudaMalloc(&dB, lenB * sizeof(B_Type)));
NVTE_CHECK_CUDA(cudaMalloc(&dD, lenD * sizeof(D_Type)));

NVTE_CHECK_CUDA(cudaMemcpy(
dA, a_data, lenA * sizeof(A_Type), cudaMemcpyHostToDevice));
NVTE_CHECK_CUDA(cudaMemcpy(
dB, b_data, lenB * sizeof(B_Type), cudaMemcpyHostToDevice));

if (bias_data) {
NVTE_CHECK_CUDA(cudaMalloc(&dBias, lenBias * sizeof(Bias_Type)));
NVTE_CHECK_CUDA(cudaMemcpy(
dBias, bias_data, lenBias * sizeof(Bias_Type),
cudaMemcpyHostToDevice));
}

if (gelu_data) {
NVTE_CHECK_CUDA(cudaMalloc(&dGelu, lenGelu * sizeof(Gelu_Type)));
NVTE_CHECK_CUDA(cudaMemset(dGelu, 0, lenGelu * sizeof(Gelu_Type)));
}

if (use_mxfp8) {
NVTE_CHECK_CUDA(cudaMalloc(&dA_scale, lenA_scale * sizeof(fp8e8m0)));
NVTE_CHECK_CUDA(cudaMalloc(&dB_scale, lenB_scale * sizeof(fp8e8m0)));
NVTE_CHECK_CUDA(cudaMemcpy(
dA_scale, a_scale_inv_mxfp8, lenA_scale * sizeof(fp8e8m0),
cudaMemcpyHostToDevice));
NVTE_CHECK_CUDA(cudaMemcpy(
dB_scale, b_scale_inv_mxfp8, lenB_scale * sizeof(fp8e8m0),
cudaMemcpyHostToDevice));
}

if (is_fp8_output && d_amax_host) {
NVTE_CHECK_CUDA(cudaMalloc(&dAmax, sizeof(float)));
NVTE_CHECK_CUDA(cudaMemset(dAmax, 0, sizeof(float)));
}

// Kernel launch
dim3 block(16, 16);
dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y);

const int nthreads = block.x * block.y;
size_t shmem_bytes = nthreads * sizeof(float);

compute_ref_kernel<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>
<<<grid, block, shmem_bytes, 0>>>(
dA,
dB,
a_scale_inv_scalar,
b_scale_inv_scalar,
dA_scale,
dB_scale,
dBias,
d_scale,
m, k, n,
dD,
dAmax,
dGelu,
transa,
transb,
is_fp8_output);

NVTE_CHECK_CUDA(cudaGetLastError());
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
Collaborator:
If both the reference and the target runs are on the GPU, we can run a single device synchronization at the very end of both runs.

Contributor Author:
Done in 325ece6
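A standalone miniature of that pattern (toy kernels, not the test's GEMMs), assuming no intermediate host-side reads are needed: kernel launches are asynchronous with respect to the host, so one cudaDeviceSynchronize() before comparing results covers both runs.

#include <cuda_runtime.h>

__global__ void reference_run(float* out) { out[0] = 1.0f; }
__global__ void target_run(float* out)    { out[0] = 2.0f; }

int main() {
  float* d_out = nullptr;
  cudaMalloc(&d_out, 2 * sizeof(float));
  reference_run<<<1, 1>>>(d_out);   // asynchronous launch
  target_run<<<1, 1>>>(d_out + 1);  // asynchronous launch
  cudaDeviceSynchronize();          // single sync before host-side checks
  float h_out[2];
  cudaMemcpy(h_out, d_out, 2 * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(d_out);
  return 0;
}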


// D2H copies
NVTE_CHECK_CUDA(cudaMemcpy(
d_data, dD, lenD * sizeof(D_Type), cudaMemcpyDeviceToHost));

if (gelu_data) {
NVTE_CHECK_CUDA(cudaMemcpy(
gelu_data, dGelu, lenGelu * sizeof(Gelu_Type),
cudaMemcpyDeviceToHost));
}

if (is_fp8_output && d_amax_host) {
NVTE_CHECK_CUDA(cudaMemcpy(
d_amax_host, dAmax, sizeof(float), cudaMemcpyDeviceToHost));
} else if (d_amax_host) {
*d_amax_host = 0.0f;
}

// cleanup
NVTE_CHECK_CUDA(cudaFree(dA));
NVTE_CHECK_CUDA(cudaFree(dB));
NVTE_CHECK_CUDA(cudaFree(dD));
if (dBias)
NVTE_CHECK_CUDA(cudaFree(dBias));
if (dGelu)
NVTE_CHECK_CUDA(cudaFree(dGelu));
if (dAmax)
NVTE_CHECK_CUDA(cudaFree(dAmax));
if (dA_scale)
NVTE_CHECK_CUDA(cudaFree(dA_scale));
if (dB_scale)
NVTE_CHECK_CUDA(cudaFree(dB_scale));
}


template <typename A_Type, typename B_Type, typename Bias_Type, typename Gelu_Type, typename D_Type>
void compute_ref(
const A_Type* a_data,
@@ -71,36 +308,21 @@ void compute_ref(
bool transa,
bool transb){

float ref_d_amax = 0;

#pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread)
for(size_t ii = 0; ii < m; ii++){
for(size_t jj = 0; jj < n; jj++){
float val = 0;
for(size_t kk = 0; kk < k; kk++){
float a_val = transa ? a_data[kk + ii*k] : a_data[ii + kk*m];
float b_val = transb ? b_data[jj + kk*n] : b_data[kk + jj*k];
val += a_scale_inv*a_val*b_scale_inv*b_val;
}
if(bias_data){
val += (float)bias_data[ii];
}
if(ref_gelu_data){
ref_gelu_data[ii + jj*m] = (Gelu_Type)(val);
val = ref_gelu(val);
}
ref_d_data[ii+jj*m] = (D_Type)(val*d_scale);
// update ref_d_amax if in fp8
DType dtype = TypeInfo<D_Type>::dtype;
if(isFp8Type(dtype)){
ref_d_amax = std::max(ref_d_amax, std::fabs(val));
}
}
}
if (ref_d_amax_ptr)
{
*ref_d_amax_ptr = ref_d_amax;
}
compute_ref_impl<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
a_data,
b_data,
/*a_scale_inv_scalar=*/a_scale_inv,
/*b_scale_inv_scalar=*/b_scale_inv,
/*a_scale_inv_mxfp8=*/nullptr,
/*b_scale_inv_mxfp8=*/nullptr,
bias_data,
d_scale,
m, k, n,
ref_d_data,
ref_d_amax_ptr,
ref_gelu_data,
transa,
transb);
}

template <typename A_Type, typename B_Type, typename Bias_Type, typename Gelu_Type, typename D_Type>
@@ -118,38 +340,21 @@ void compute_mxfp8_ref(
bool transa,
bool transb){

float ref_d_amax = 0;

#pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread)
for(size_t ii = 0; ii < m; ii++){
for(size_t jj = 0; jj < n; jj++){
float val = 0;
for(size_t kk = 0; kk < k; kk++){
size_t a_idx = transa ? (ii*k + kk) : (kk*m + ii);
size_t b_idx = transb ? (kk*n + jj) : (jj*k + kk);
float a_scale_inv_val = std::exp2f(a_scale_inv_data[transa ? a_idx/32 : (kk/32 * m + ii)] - 127);
float b_scale_inv_val = std::exp2f(b_scale_inv_data[transb ? (kk/32 * n + jj) : b_idx/32] - 127);
val += a_scale_inv_val * (float)a_data[a_idx] * b_scale_inv_val * (float)b_data[b_idx];
}
if(bias_data){
val += (float)bias_data[ii];
}
if(ref_gelu_data){
ref_gelu_data[ii + jj*m] = (Gelu_Type)(val);
val = ref_gelu(val);
}
ref_d_data[ii+jj*m] = (D_Type)(val*d_scale);
// update ref_d_amax if in fp8
DType dtype = TypeInfo<D_Type>::dtype;
if(isFp8Type(dtype)){
ref_d_amax = std::max(ref_d_amax, std::fabs(val));
}
}
}
if (ref_d_amax_ptr)
{
*ref_d_amax_ptr = ref_d_amax;
}
compute_ref_impl<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
a_data,
b_data,
/*a_scale_inv_scalar=*/1.0f,
/*b_scale_inv_scalar=*/1.0f,
/*a_scale_inv_mxfp8=*/a_scale_inv_data,
/*b_scale_inv_mxfp8=*/b_scale_inv_data,
bias_data,
d_scale,
m, k, n,
ref_d_data,
ref_d_amax_ptr,
ref_gelu_data,
transa,
transb);
}

template <typename Type>
@@ -371,7 +576,7 @@ void performTest(const TestParams& params) {
pre_gelu_out.to_cpu();
}

//perform the gemm in CPU
//perform the reference gemm on GPU
std::unique_ptr<D_Type[]> ref_D = std::make_unique<D_Type[]>(params.m*params.n);
std::unique_ptr<Gelu_Type[]> ref_pre_gelu_out;
if(params.use_gelu){