@@ -1253,12 +1253,14 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor
                "Grouped GEMM: A and D must have the same num_tensors");
 
     // Validate alpha/beta have per-matrix values
-    const size_t alpha_numel = alpha_tensor->data.shape.numel();
-    const size_t beta_numel = beta_tensor->data.shape.numel();
-    NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors,
-               ") elements, got ", alpha_numel);
-    NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors,
-               ") elements, got ", beta_numel);
+    const size_t alpha_numel = alpha_tensor->data.numel();
+    const size_t beta_numel = beta_tensor->data.numel();
+    NVTE_CHECK(alpha_numel == num_tensors,
+               "Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ",
+               alpha_numel);
+    NVTE_CHECK(beta_numel == num_tensors,
+               "Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ",
+               beta_numel);
 
     auto is_fp8_or_16bit = [](transformer_engine::DType dtype) {
       return dtype == transformer_engine::DType::kFloat8E4M3 ||
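For context on the change itself: replacing alpha_tensor->data.shape.numel() with alpha_tensor->data.numel() suggests the element count is a method on the tensor data object, while .shape is a plain container without one. The following is a minimal standalone sketch of that layout, assuming a hypothetical DataView type as a stand-in; it is not TransformerEngine's actual definition.

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the tensor data block. numel() lives on the
// data object rather than on the raw shape vector, which is why the diff
// drops the ".shape" link from the call chain.
struct DataView {
  std::vector<size_t> shape;
  size_t numel() const {
    // Product of all dimension sizes; empty shape counts as 1 element.
    return std::accumulate(shape.begin(), shape.end(), size_t{1},
                           std::multiplies<size_t>());
  }
};

int main() {
  DataView alpha{{4}};  // one scaling factor per GEMM in a group of 4
  return alpha.numel() == 4 ? 0 : 1;
}

Under this reading, .shape stays a plain std::vector<size_t> and every element-count query goes through a single numel() helper on the data object, which matches the call chain the new code uses.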