
Commit fb027d0

fix

Signed-off-by: Pawel Gadzinski <[email protected]>
Parent: a702426

5 files changed: +33 −60 lines
tests/cpp/operator/test_grouped_gemm.cu
Lines changed: 14 additions & 28 deletions

@@ -279,8 +279,6 @@ Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shap
 
 Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& shape) {
   Tensor t(name, shape, DType::kBFloat16);
-  // Fill with ones for easier debugging
-  //fillUniform(&t);
   const size_t numel = shape[0] * shape[1];
   std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f));
   NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(),

@@ -293,8 +291,7 @@ struct TestParams {
   bool transa;
   bool transb;
   ShapeCase shape_case;
-  bool use_null_c = false;             // When true, pass nullptr for C (valid when beta=0)
-  bool use_split_accumulator = false;  // Whether to use split accumulator for FP8 GEMM
+  bool use_null_c = false;  // When true, pass nullptr for C (valid when beta=0)
 };
 
 // Returns a vector of (M, N, K) tuples for each GEMM in the group.

@@ -397,7 +394,7 @@ void run_grouped_gemm_case(const TestParams& params) {
                     false,  // grad
                     workspace_ptrs.data(),
                     false,  // accumulate
-                    params.use_split_accumulator,
+                    false,  // use_split_accumulator
                     0,  // sm_count
                     0);

@@ -450,10 +447,6 @@ void run_grouped_gemm_case(const TestParams& params) {
   Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
   Tensor cublas_ws("cublas_ws", std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
 
-  // Create config with use_split_accumulator setting
-  transformer_engine::GroupedMatmulConfigWrapper config;
-  config.set_use_split_accumulator(params.use_split_accumulator);
-
   nvte_grouped_gemm(params.transa,
                     params.transb,
                     alpha_tensor.data(),

@@ -464,7 +457,7 @@ void run_grouped_gemm_case(const TestParams& params) {
                     grouped_D.get_handle(),
                     setup_ws.data(),
                     cublas_ws.data(),
-                    config,
+                    nullptr,  // config (use defaults)
                     0);
 
   for (size_t i = 0; i < num_gemms; ++i) {

@@ -502,29 +495,22 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest
   const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") +
                              "tb" + (info.param.transb ? "T" : "N");
   const std::string null_c = info.param.use_null_c ? "_NullC" : "";
-  const std::string split_acc = info.param.use_split_accumulator ? "_SplitAcc" : "";
   return std::string(kInputNames[static_cast<int>(info.param.input_case)]) + "_" +
-         kShapeNames[static_cast<int>(info.param.shape_case)] + "_" + layout + null_c + split_acc;
+         kShapeNames[static_cast<int>(info.param.shape_case)] + "_" + layout + null_c;
 }
 
-// TestParams: {input_case, transa, transb, shape_case, use_null_c, use_split_accumulator}
+// TestParams: {input_case, transa, transb, shape_case, use_null_c}
 const std::vector<TestParams> kTestParams = {
-    // Basic tests (no split accumulator)
-    {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false, false},
-    {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false, false},
-    {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false, false},
-    {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false, false},
-    {InputCase::kBF16, false, true, ShapeCase::kSameLast, false, false},
-    {InputCase::kBF16, false, false, ShapeCase::kAllSame, false, false},
-    {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false, false},
+    // Basic tests
+    {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
+    {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
+    {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
+    {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
+    {InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
+    {InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
+    {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
     // Test NULL C (valid when beta=0)
-    {InputCase::kBF16, false, false, ShapeCase::kAllSame, true, false},
-
-    // Split accumulator tests
-    {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false, true},
-    {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false, true},
-    {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false, true},
-    {InputCase::kFP8Current, true, false, ShapeCase::kSameFirst, false, true},
+    {InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
 };
 
 INSTANTIATE_TEST_SUITE_P(OperatorTest,
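
For a concrete picture of the trimmed naming scheme: assuming kInputNames[static_cast<int>(InputCase::kFP8Current)] is "FP8Current" and kShapeNames[static_cast<int>(ShapeCase::kAllDifferent)] is "AllDifferent" (neither table is visible in this diff), MakeGroupedGemmTestName now yields names like:

    // Hypothetical names for two of the remaining kTestParams entries:
    //   {InputCase::kFP8Current, true,  false, ShapeCase::kAllDifferent, false} -> "FP8Current_AllDifferent_taTtbN"
    //   {InputCase::kBF16,       false, false, ShapeCase::kAllSame,      true } -> "BF16_AllSame_taNtbN_NullC"
    // The "_SplitAcc" suffix can no longer appear.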

transformer_engine/common/gemm/config.cpp
Lines changed: 0 additions & 6 deletions

@@ -154,9 +154,6 @@ void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
     case kNVTEGroupedMatmulConfigAvgK:
       std::memcpy(buf, &config_.avg_k, attr_size);
       break;
-    case kNVTEGroupedMatmulConfigUseSplitAccumulator:
-      std::memcpy(buf, &config_.use_split_accumulator, attr_size);
-      break;
     case kNVTEGroupedMatmulConfigSMCount:
       std::memcpy(buf, &config_.sm_count, attr_size);
       break;

@@ -195,9 +192,6 @@ void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
       std::memcpy(&config_.avg_k, buf, attr_size);
       config_.avg_k_set = true;
       break;
-    case kNVTEGroupedMatmulConfigUseSplitAccumulator:
-      std::memcpy(&config_.use_split_accumulator, buf, attr_size);
-      break;
     case kNVTEGroupedMatmulConfigSMCount:
      std::memcpy(&config_.sm_count, buf, attr_size);
      break;
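
For reference, a minimal caller-side sketch of the surviving attribute round-trip. The old test passed a GroupedMatmulConfigWrapper directly where an NVTEGroupedMatmulConfig is expected, which suggests an implicit conversion; that conversion and the exact C prototypes (inferred from the memcpy bodies above) are assumptions:

    // Sketch only: set sm_count through the wrapper, read it back through the C API.
    transformer_engine::GroupedMatmulConfigWrapper config;
    config.set_sm_count(132);  // hypothetical value
    int check = 0;
    nvte_get_grouped_matmul_config_attribute(config, kNVTEGroupedMatmulConfigSMCount,
                                             &check, sizeof(int));  // expect check == 132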

transformer_engine/common/gemm/config.h
Lines changed: 0 additions & 4 deletions

@@ -40,9 +40,6 @@ struct GroupedMatmulConfig {
   int64_t avg_n = 0;
   int64_t avg_k = 0;
 
-  // Whether to use split accumulator for FP8 GEMM (more accurate but slower)
-  bool use_split_accumulator = true;
-
   // Number of streaming multiprocessors to use in GEMM kernel
   int sm_count = 0;
 
@@ -55,7 +52,6 @@ struct GroupedMatmulConfig {
       sizeof(int64_t),  // avg_m
       sizeof(int64_t),  // avg_n
       sizeof(int64_t),  // avg_k
-      sizeof(bool),     // use_split_accumulator
       sizeof(int)       // sm_count
   };
 };
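
Because this per-attribute size table is indexed by NVTEGroupedMatmulConfigAttribute, the sizeof(bool) entry has to disappear together with the enum value. A consistency-check sketch, assuming the table is a static constexpr array member named attr_sizes (only its initializer is visible in this diff):

    #include <iterator>  // std::size

    static_assert(std::size(GroupedMatmulConfig::attr_sizes) ==
                      kNVTEGroupedMatmulConfigNumAttributes,
                  "attr_sizes must stay in sync with NVTEGroupedMatmulConfigAttribute");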

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Lines changed: 18 additions & 13 deletions

@@ -310,15 +310,17 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA,
 
   // For column-major layout: leading dimension is the number of rows in storage.
   // If columnwise data was chosen, storage is already transposed.
-  int *rowa = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M);
-  int *cola = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K);
-  int *lda = rowa;
-  int *rowb = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K);
-  int *colb = B_sel.use_columnwise ? ws.K : (B_sel.trans ? ws.K : ws.N);
-  int *ldb = rowb;
-
-  NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rowa, cola, lda));
-  NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rowb, colb, ldb));
+  // Storage dimensions for A: rows_A x cols_A with leading dimension lda_storage
+  int *rows_A = A_sel.use_columnwise ? ws.M : (A_sel.trans ? ws.K : ws.M);
+  int *cols_A = A_sel.use_columnwise ? ws.K : (A_sel.trans ? ws.M : ws.K);
+  int *lda_storage = rows_A;
+  // Storage dimensions for B: rows_B x cols_B with leading dimension ldb_storage
+  int *rows_B = B_sel.use_columnwise ? ws.N : (B_sel.trans ? ws.N : ws.K);
+  int *cols_B = B_sel.use_columnwise ? ws.K : (B_sel.trans ? ws.K : ws.N);
+  int *ldb_storage = rows_B;
+
+  NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, rows_A, cols_A, lda_storage));
+  NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, rows_B, cols_B, ldb_storage));
   NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.M, ws.N, ws.M));
   NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.M, ws.N, ws.M));
 }

@@ -442,14 +444,15 @@ __global__ void setup_grouped_gemm_kernel(
       D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last);
 
   // Compute data pointers
+  // Note: const_cast is safe here - cuBLAS requires void** but won't modify A/B/C data
   A_ptrs[idx] = const_cast<char *>(a_base) + a_offset * a_elem_size;
   B_ptrs[idx] = const_cast<char *>(b_base) + b_offset * b_elem_size;
   C_ptrs[idx] = const_cast<char *>(c_base) + c_offset * c_elem_size;
   D_ptrs[idx] = d_base + d_offset * d_elem_size;
 
-  // Compute M, N, K dimensions
-  // Test stores A as {K,M} when !transa, {M,K} when transa
-  // Test stores B as {N,K} when !transb, {K,N} when transb
+  // Compute M, N, K dimensions from tensor shapes
+  // Input A is stored as {K,M} when !transa, {M,K} when transa
+  // Input B is stored as {N,K} when !transb, {K,N} when transb
   M[idx] = static_cast<int>(transa ? a_first : a_last);
   K[idx] = static_cast<int>(transa ? a_last : a_first);
   N[idx] = static_cast<int>(transb ? b_last : b_first);

@@ -570,9 +573,11 @@ void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVT
 
   // Set fast accumulation mode for FP8
   // Fast accumulation: 0 = split accumulator (more accurate), 1 = fast accumulator
+  // Note: cuBLASLt grouped GEMM API does not support configurable split accumulator,
+  // we always use fast accumulator for performance.
   const bool is_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype);
   if (is_fp8) {
-    int8_t fastAccuMode = config_.use_split_accumulator ? 0 : 1;
+    int8_t fastAccuMode = 1;  // Always use fast accumulator
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
         &matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode)));
   }
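
A worked example of the renamed layout selection, with illustrative numbers not taken from the diff. For GEMM i with M=128, N=64, K=256:

    // !A_sel.use_columnwise, !A_sel.trans: A stored column-major as M x K
    //   rows_A[i] = 128 (M), cols_A[i] = 256 (K), lda_storage[i] = 128
    // !A_sel.use_columnwise, A_sel.trans: A stored column-major as K x M
    //   rows_A[i] = 256 (K), cols_A[i] = 128 (M), lda_storage[i] = 256
    // A_sel.use_columnwise: storage is already transposed, so rows_A = ws.M and
    //   cols_A = ws.K regardless of A_sel.trans.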

transformer_engine/common/include/transformer_engine/gemm.h
Lines changed: 1 addition & 9 deletions

@@ -82,10 +82,8 @@ enum NVTEGroupedMatmulConfigAttribute {
    * computed automatically from A's logical shape.
    */
   kNVTEGroupedMatmulConfigAvgK = 2,
-  /*! Whether to use split accumulator for FP8 GEMM. */
-  kNVTEGroupedMatmulConfigUseSplitAccumulator = 3,
   /*! Number of streaming multiprocessors to use in GEMM kernel. */
-  kNVTEGroupedMatmulConfigSMCount = 4,
+  kNVTEGroupedMatmulConfigSMCount = 3,
   kNVTEGroupedMatmulConfigNumAttributes
 };

@@ -487,12 +485,6 @@ class GroupedMatmulConfigWrapper {
                                              sizeof(int64_t));
   }
 
-  /*! \brief Set whether to use split accumulator for FP8 GEMM. */
-  void set_use_split_accumulator(bool use_split_accumulator) {
-    nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigUseSplitAccumulator,
-                                             &use_split_accumulator, sizeof(bool));
-  }
-
   /*! \brief Set number of streaming multiprocessors to use. */
   void set_sm_count(int sm_count) {
     nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount,
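
With set_use_split_accumulator() removed, a caller migrating across this change keeps only the remaining setters. A minimal sketch (set_sm_count is shown above; the sm_count value is hypothetical):

    transformer_engine::GroupedMatmulConfigWrapper config;
    config.set_sm_count(132);  // hypothetical value
    // Pass `config` to nvte_grouped_gemm, or pass nullptr to use defaults,
    // as the updated test in this commit now does.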
