Commit de51c96
[NVFP4][MOE] Bug Fix for NVFP4 Grouped Quant (#2564)
* fix

Signed-off-by: Zhongbo Zhu <[email protected]>

* resolve review comments

Signed-off-by: Zhongbo Zhu <[email protected]>

* Comment tweaks

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Zhongbo Zhu <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
1 parent 702fc5e commit de51c96

3 files changed: +30 −14 lines changed

transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu

Lines changed: 21 additions & 12 deletions

@@ -1125,8 +1125,9 @@ template <bool kEnableStochasticRounding, bool kEnableRHTColQuant, bool kEnableR
 void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_size, TA const *A,
                                       TB const *B, TQA *QA, TSFA *SFA,
                                       MultiAmaxHadamardCastFusionArgs &args,
-                                      const size_t *rng_state, uint32_t sm_count,
-                                      cudaStream_t stream, int k_tile_size = 1024) {
+                                      const size_t *rng_state, uint32_t *tile_scheduler_workspace,
+                                      uint32_t sm_count, cudaStream_t stream,
+                                      int k_tile_size = 1024) {
   using namespace cute;
   static int constexpr SFVecSize = 16;
   static int constexpr RhtTensorSize = 16;
@@ -1295,10 +1296,9 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
   NVTE_CHECK_CUDA(
       cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

-  // Allocate workspace and set to zero
-  void *tile_scheduler_workspace = nullptr;
-  NVTE_CHECK_CUDA(cudaMallocAsync(&tile_scheduler_workspace, sizeof(uint32_t), stream));
-  NVTE_CHECK_CUDA(cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream));
+  // Set workspace and set to zero
+  NVTE_CHECK_CUDA(cudaMemsetAsync(reinterpret_cast<void *>(tile_scheduler_workspace), 0,
+                                  sizeof(uint32_t), stream));

   // Launch kernel
   cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};
@@ -1308,8 +1308,6 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
                                         tile_scheduler_workspace, mma, rng_state);
   NVTE_CHECK_CUDA(cudaGetLastError());
   NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed.");
-
-  NVTE_CHECK_CUDA(cudaFreeAsync(tile_scheduler_workspace, stream));
 }

 }  // namespace
@@ -1318,7 +1316,8 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
 void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tensor *> &output_list,
                                           const size_t *split_sections, size_t num_tensors,
                                           const Tensor &hadamard_matrix_,
-                                          QuantizationConfig &quant_config, cudaStream_t stream) {
+                                          QuantizationConfig &quant_config, Tensor &quant_workspace,
+                                          cudaStream_t stream) {
   NVTE_API_CALL(group_hadamard_transform_cast_fusion);

   using transformer_engine::detail::kMaxTensorsPerKernel;
@@ -1399,6 +1398,12 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tens
     rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
   }

+  uint32_t *tile_scheduler_workspace = nullptr;
+  NVTE_CHECK(quant_workspace.data.dptr != nullptr, "Quantization workspace must be provided.");
+  NVTE_CHECK(quant_workspace.data.buffer_size_bytes() >= sizeof(uint32_t),
+             "Quantization workspace must be at least 4 bytes.");
+  tile_scheduler_workspace = reinterpret_cast<uint32_t *>(quant_workspace.data.dptr);
+
   // Template arguments
   using TA = cute::bfloat16_t;
   using TB = cute::bfloat16_t;
@@ -1461,7 +1466,9 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tens
         /*QA=*/reinterpret_cast<TQA *>(rowwise_data_base_ptr),
         /*SFA=*/reinterpret_cast<TSFA *>(rowwise_scale_inv_base_ptr),
         /*args=*/kernel_args,
-        /*rng_state=*/rng_state, /*sm_count=*/sm_count,
+        /*rng_state=*/rng_state,
+        /*tile_scheduler_workspace=*/tile_scheduler_workspace,
+        /*sm_count=*/sm_count,
         /*stream=*/stream, /*k_tile_size=*/k_tile_size);
   } else {
     NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=",
@@ -1478,7 +1485,7 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
                                                const size_t *split_sections,
                                                const size_t num_tensors,
                                                const NVTEQuantizationConfig quant_config,
-                                               cudaStream_t stream) {
+                                               NVTETensor quant_workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion);
   using namespace transformer_engine;
   NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
@@ -1489,6 +1496,8 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
     output_list[i] = convertNVTETensorCheck(outputs[i]);
   }

+  Tensor *quant_workspace_tensor = convertNVTETensorCheck(quant_workspace);
+
   QuantizationConfig quant_config_cpp;
   if (quant_config != nullptr) {
     quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
@@ -1497,5 +1506,5 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
   // Call the multi-tensor Hadamard transform amax implementation.
   group_hadamard_transform_cast_fusion(*input_tensor, output_list, split_sections, num_tensors,
                                        *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
-                                       stream);
+                                       *quant_workspace_tensor, stream);
 }
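Taken together, the change in this file moves the tile-scheduler counter's lifetime out of the launch path: instead of a cudaMallocAsync/cudaFreeAsync pair around every launch, the launcher now only zeroes a caller-provided buffer on the stream. The following is a minimal, self-contained CUDA sketch of that pattern; the kernel, names, and launch shape are illustrative stand-ins, not Transformer Engine code.

// Sketch (illustrative, not Transformer Engine code) of the workspace-lifetime
// change: the launch path no longer allocates or frees the tile-scheduler
// counter, it only zeroes a buffer the caller owns and reuses across launches.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#define CHECK_CUDA(call)                                             \
  do {                                                               \
    cudaError_t err_ = (call);                                       \
    if (err_ != cudaSuccess) {                                       \
      std::fprintf(stderr, "CUDA: %s\n", cudaGetErrorString(err_));  \
      std::abort();                                                  \
    }                                                                \
  } while (0)

// A persistent tile scheduler hands out work by atomically bumping a global
// counter; each thread loops until every tile has been claimed.
__global__ void claim_tiles(uint32_t *tile_counter, uint32_t num_tiles) {
  for (;;) {
    uint32_t tile = atomicAdd(tile_counter, 1u);
    if (tile >= num_tiles) break;
    // ... process tile ...
  }
}

// After the fix: zero the caller-provided counter on the stream, then launch.
void launch(uint32_t *tile_scheduler_workspace, uint32_t num_tiles,
            cudaStream_t stream) {
  CHECK_CUDA(cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream));
  claim_tiles<<<4, 128, 0, stream>>>(tile_scheduler_workspace, num_tiles);
  CHECK_CUDA(cudaGetLastError());
}

int main() {
  cudaStream_t stream;
  CHECK_CUDA(cudaStreamCreate(&stream));
  uint32_t *workspace = nullptr;
  CHECK_CUDA(cudaMalloc(&workspace, sizeof(uint32_t)));  // allocated once by the caller
  launch(workspace, 1024, stream);                       // reused on every call
  launch(workspace, 2048, stream);
  CHECK_CUDA(cudaStreamSynchronize(stream));
  CHECK_CUDA(cudaFree(workspace));
  CHECK_CUDA(cudaStreamDestroy(stream));
  return 0;
}

With this arrangement the caller pays for one allocation up front, and each launch costs only a small asynchronous memset, keeping per-call device-allocator traffic out of the hot path.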

transformer_engine/common/include/transformer_engine/hadamard_transform.h

Lines changed: 2 additions & 1 deletion

@@ -115,13 +115,14 @@ void nvte_group_hadamard_transform_cast_fusion_columnwise(
  * \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
  * \param[in] num_tensors Number of output tensors, must be > 0.
  * \param[in] quant_config Quantization configuration.
+ * \param[in] quant_workspace Workspace buffer. Must be at least 4 bytes.
  * \param[in] stream CUDA stream used for the operation.
  */
 void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
                                                const NVTETensor hadamard_matrix,
                                                const size_t* split_sections, size_t num_tensors,
                                                const NVTEQuantizationConfig quant_config,
-                                               cudaStream_t stream);
+                                               NVTETensor quant_workspace, cudaStream_t stream);

 #ifdef __cplusplus
 }  // extern "C"

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 7 additions & 1 deletion

@@ -872,10 +872,16 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
   auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix);

   if (all_aligned_token_dim) {
+    // allocate a tile scheduler workspace
+    auto tile_scheduler_workspace_torch =
+        at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
+    auto nvte_tile_scheduler_workspace =
+        makeTransformerEngineTensor(tile_scheduler_workspace_torch);
     // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose
     nvte_group_hadamard_transform_cast_fusion(
         input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
-        rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream);
+        rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0],
+        nvte_tile_scheduler_workspace.data(), stream);
   } else {
     // Separate quantization for rowwise usage and columnwise usage
     // Rowwise quantization fusion with grouped version
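On the PyTorch side, the binding satisfies the new parameter by allocating the one-element int32 counter with at::empty, which draws from PyTorch's CUDA caching allocator rather than calling into the raw CUDA allocator. Below is a minimal standalone libtorch sketch of the same pattern, assuming a CUDA-enabled libtorch build; it is illustrative, not the binding itself.

// Sketch of the binding's allocation pattern (assumes CUDA-enabled libtorch;
// not Transformer Engine code). at::empty is served by PyTorch's caching
// allocator, so repeated calls recycle the same block in steady state.
#include <torch/torch.h>

int main() {
  if (!torch::cuda::is_available()) return 0;
  // One int32 element == the 4 bytes the grouped kernel requires.
  auto tile_scheduler_workspace =
      at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
  // No host-side zeroing is needed: the launcher memsets the buffer on the
  // stream before each launch (see the .cu change above).
  return 0;
}

Tying the buffer to a proper tensor also hands its lifetime to PyTorch's allocator instead of managing it manually inside the launcher, which is what the removed cudaMallocAsync/cudaFreeAsync pair used to do.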
