@@ -279,8 +279,6 @@ Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shap
279279
280280Tensor make_bf16_operand (const std::string& name, const std::vector<size_t >& shape) {
281281 Tensor t (name, shape, DType::kBFloat16 );
282- // Fill with ones for easier debugging
283- // fillUniform(&t);
284282 const size_t numel = shape[0 ] * shape[1 ];
285283 std::vector<__nv_bfloat16> ones (numel, __float2bfloat16 (1 .0f ));
286284 NVTE_CHECK_CUDA (cudaMemcpy (t.rowwise_dptr (), ones.data (),
@@ -293,8 +291,7 @@ struct TestParams {
293291 bool transa;
294292 bool transb;
295293 ShapeCase shape_case;
296- bool use_null_c = false ; // When true, pass nullptr for C (valid when beta=0)
297- bool use_split_accumulator = false ; // Whether to use split accumulator for FP8 GEMM
294+ bool use_null_c = false ; // When true, pass nullptr for C (valid when beta=0)
298295};
299296
300297// Returns a vector of (M, N, K) tuples for each GEMM in the group.
@@ -397,7 +394,7 @@ void run_grouped_gemm_case(const TestParams& params) {
397394 false , // grad
398395 workspace_ptrs.data (),
399396 false , // accumulate
400- params. use_split_accumulator ,
397+ false , // use_split_accumulator
401398 0 , // sm_count
402399 0 );
403400
@@ -450,10 +447,6 @@ void run_grouped_gemm_case(const TestParams& params) {
450447 Tensor setup_ws (" setup_ws" , std::vector<size_t >{setup_ws_bytes}, DType::kByte );
451448 Tensor cublas_ws (" cublas_ws" , std::vector<size_t >{cublas_ws_bytes}, DType::kByte );
452449
453- // Create config with use_split_accumulator setting
454- transformer_engine::GroupedMatmulConfigWrapper config;
455- config.set_use_split_accumulator (params.use_split_accumulator );
456-
457450 nvte_grouped_gemm (params.transa ,
458451 params.transb ,
459452 alpha_tensor.data (),
@@ -464,7 +457,7 @@ void run_grouped_gemm_case(const TestParams& params) {
464457 grouped_D.get_handle (),
465458 setup_ws.data (),
466459 cublas_ws.data (),
467- config,
460+ nullptr , // config (use defaults)
468461 0 );
469462
470463 for (size_t i = 0 ; i < num_gemms; ++i) {
@@ -502,29 +495,22 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest
502495 const std::string layout = std::string (" ta" ) + (info.param .transa ? " T" : " N" ) +
503496 " tb" + (info.param .transb ? " T" : " N" );
504497 const std::string null_c = info.param .use_null_c ? " _NullC" : " " ;
505- const std::string split_acc = info.param .use_split_accumulator ? " _SplitAcc" : " " ;
506498 return std::string (kInputNames [static_cast <int >(info.param .input_case )]) + " _" +
507- kShapeNames [static_cast <int >(info.param .shape_case )] + " _" + layout + null_c + split_acc ;
499+ kShapeNames [static_cast <int >(info.param .shape_case )] + " _" + layout + null_c;
508500}
509501
510- // TestParams: {input_case, transa, transb, shape_case, use_null_c, use_split_accumulator }
502+ // TestParams: {input_case, transa, transb, shape_case, use_null_c}
511503const std::vector<TestParams> kTestParams = {
512- // Basic tests (no split accumulator)
513- {InputCase::kFP8Current , true , false , ShapeCase::kAllDifferent , false , false },
514- {InputCase::kFP8Current , false , true , ShapeCase::kAllDifferent , false , false },
515- {InputCase::kFP8Current , false , false , ShapeCase::kAllSame , false , false },
516- {InputCase::kBF16 , true , false , ShapeCase::kSameFirst , false , false },
517- {InputCase::kBF16 , false , true , ShapeCase::kSameLast , false , false },
518- {InputCase::kBF16 , false , false , ShapeCase::kAllSame , false , false },
519- {InputCase::kBF16 , true , true , ShapeCase::kAllDifferent , false , false },
504+ // Basic tests
505+ {InputCase::kFP8Current , true , false , ShapeCase::kAllDifferent , false },
506+ {InputCase::kFP8Current , false , true , ShapeCase::kAllDifferent , false },
507+ {InputCase::kFP8Current , false , false , ShapeCase::kAllSame , false },
508+ {InputCase::kBF16 , true , false , ShapeCase::kSameFirst , false },
509+ {InputCase::kBF16 , false , true , ShapeCase::kSameLast , false },
510+ {InputCase::kBF16 , false , false , ShapeCase::kAllSame , false },
511+ {InputCase::kBF16 , true , true , ShapeCase::kAllDifferent , false },
520512 // Test NULL C (valid when beta=0)
521- {InputCase::kBF16 , false , false , ShapeCase::kAllSame , true , false },
522-
523- // Split accumulator tests
524- {InputCase::kFP8Current , true , false , ShapeCase::kAllDifferent , false , true },
525- {InputCase::kFP8Current , false , true , ShapeCase::kAllDifferent , false , true },
526- {InputCase::kFP8Current , false , false , ShapeCase::kAllSame , false , true },
527- {InputCase::kFP8Current , true , false , ShapeCase::kSameFirst , false , true },
513+ {InputCase::kBF16 , false , false , ShapeCase::kAllSame , true },
528514};
529515
530516INSTANTIATE_TEST_SUITE_P (OperatorTest,
0 commit comments