Skip to content

Commit 1167f75

Browse files
committed
Fix alpha/beta numel - use SimpleTensor::numel()
Signed-off-by: Piotr Gadzinski <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
1 parent 101766b commit 1167f75

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

transformer_engine/common/gemm/cublaslt_gemm.cu

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,12 +1253,14 @@ inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor
12531253
"Grouped GEMM: A and D must have the same num_tensors");
12541254

12551255
// Validate alpha/beta have per-matrix values
1256-
const size_t alpha_numel = alpha_tensor->data.shape.numel();
1257-
const size_t beta_numel = beta_tensor->data.shape.numel();
1258-
NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors,
1259-
") elements, got ", alpha_numel);
1260-
NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors,
1261-
") elements, got ", beta_numel);
1256+
const size_t alpha_numel = alpha_tensor->data.numel();
1257+
const size_t beta_numel = beta_tensor->data.numel();
1258+
NVTE_CHECK(alpha_numel == num_tensors,
1259+
"Grouped GEMM: alpha must have num_tensors (", num_tensors, ") elements, got ",
1260+
alpha_numel);
1261+
NVTE_CHECK(beta_numel == num_tensors,
1262+
"Grouped GEMM: beta must have num_tensors (", num_tensors, ") elements, got ",
1263+
beta_numel);
12621264

12631265
auto is_fp8_or_16bit = [](transformer_engine::DType dtype) {
12641266
return dtype == transformer_engine::DType::kFloat8E4M3 ||

0 commit comments

Comments
 (0)