@@ -1125,8 +1125,9 @@ template <bool kEnableStochasticRounding, bool kEnableRHTColQuant, bool kEnableR
11251125void group_row_col_rht_gemm_ntt_w_sfc (int packed_sequence_length, int hidden_size, TA const *A,
11261126 TB const *B, TQA *QA, TSFA *SFA,
11271127 MultiAmaxHadamardCastFusionArgs &args,
1128- const size_t *rng_state, uint32_t sm_count,
1129- cudaStream_t stream, int k_tile_size = 1024 ) {
1128+ const size_t *rng_state, uint32_t *tile_scheduler_workspace,
1129+ uint32_t sm_count, cudaStream_t stream,
1130+ int k_tile_size = 1024 ) {
11301131 using namespace cute ;
11311132 static int constexpr SFVecSize = 16 ;
11321133 static int constexpr RhtTensorSize = 16 ;
@@ -1295,10 +1296,9 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
12951296 NVTE_CHECK_CUDA (
12961297 cudaFuncSetAttribute (*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
12971298
1298- // Allocate workspace and set to zero
1299- void *tile_scheduler_workspace = nullptr ;
1300- NVTE_CHECK_CUDA (cudaMallocAsync (&tile_scheduler_workspace, sizeof (uint32_t ), stream));
1301- NVTE_CHECK_CUDA (cudaMemsetAsync (tile_scheduler_workspace, 0 , sizeof (uint32_t ), stream));
1299+ // Set workspace and set to zero
1300+ NVTE_CHECK_CUDA (cudaMemsetAsync (reinterpret_cast <void *>(tile_scheduler_workspace), 0 ,
1301+ sizeof (uint32_t ), stream));
13021302
13031303 // Launch kernel
13041304 cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};
@@ -1308,8 +1308,6 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
13081308 tile_scheduler_workspace, mma, rng_state);
13091309 NVTE_CHECK_CUDA (cudaGetLastError ());
13101310 NVTE_CHECK (status == cutlass::Status::kSuccess , " Kernel launch failed." );
1311-
1312- NVTE_CHECK_CUDA (cudaFreeAsync (tile_scheduler_workspace, stream));
13131311}
13141312
13151313} // namespace
@@ -1318,7 +1316,8 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
13181316void group_hadamard_transform_cast_fusion (const Tensor &input_, std::vector<Tensor *> &output_list,
13191317 const size_t *split_sections, size_t num_tensors,
13201318 const Tensor &hadamard_matrix_,
1321- QuantizationConfig &quant_config, cudaStream_t stream) {
1319+ QuantizationConfig &quant_config, Tensor &quant_workspace,
1320+ cudaStream_t stream) {
13221321 NVTE_API_CALL (group_hadamard_transform_cast_fusion);
13231322
13241323 using transformer_engine::detail::kMaxTensorsPerKernel ;
@@ -1399,6 +1398,12 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tens
13991398 rng_state = reinterpret_cast <const size_t *>(rng_state_tensor.data .dptr );
14001399 }
14011400
1401+ uint32_t *tile_scheduler_workspace = nullptr ;
1402+ NVTE_CHECK (quant_workspace.data .dptr != nullptr , " Quantization workspace must be provided." );
1403+ NVTE_CHECK (quant_workspace.data .buffer_size_bytes () >= sizeof (uint32_t ),
1404+ " Quantization workspace must be at least 4 bytes." );
1405+ tile_scheduler_workspace = reinterpret_cast <uint32_t *>(quant_workspace.data .dptr );
1406+
14021407 // Template arguments
14031408 using TA = cute::bfloat16_t ;
14041409 using TB = cute::bfloat16_t ;
@@ -1461,7 +1466,9 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tens
14611466 /* QA=*/ reinterpret_cast <TQA *>(rowwise_data_base_ptr),
14621467 /* SFA=*/ reinterpret_cast <TSFA *>(rowwise_scale_inv_base_ptr),
14631468 /* args=*/ kernel_args,
1464- /* rng_state=*/ rng_state, /* sm_count=*/ sm_count,
1469+ /* rng_state=*/ rng_state,
1470+ /* tile_scheduler_workspace=*/ tile_scheduler_workspace,
1471+ /* sm_count=*/ sm_count,
14651472 /* stream=*/ stream, /* k_tile_size=*/ k_tile_size);
14661473 } else {
14671474 NVTE_ERROR (" Invalid kernel configuration (kEnableRHTColQuant=" ,
@@ -1478,7 +1485,7 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
14781485 const size_t *split_sections,
14791486 const size_t num_tensors,
14801487 const NVTEQuantizationConfig quant_config,
1481- cudaStream_t stream) {
1488+ NVTETensor quant_workspace, cudaStream_t stream) {
14821489 NVTE_API_CALL (nvte_group_hadamard_transform_cast_fusion);
14831490 using namespace transformer_engine ;
14841491 NVTE_CHECK (num_tensors > 0 , " Number of tensors should be greater than 0." );
@@ -1489,6 +1496,8 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
14891496 output_list[i] = convertNVTETensorCheck (outputs[i]);
14901497 }
14911498
1499+ Tensor *quant_workspace_tensor = convertNVTETensorCheck (quant_workspace);
1500+
14921501 QuantizationConfig quant_config_cpp;
14931502 if (quant_config != nullptr ) {
14941503 quant_config_cpp = *reinterpret_cast <QuantizationConfig *>(quant_config);
@@ -1497,5 +1506,5 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
14971506 // Call the multi-tensor Hadamard transform amax implementation.
14981507 group_hadamard_transform_cast_fusion (*input_tensor, output_list, split_sections, num_tensors,
14991508 *convertNVTETensorCheck (hadamard_matrix), quant_config_cpp,
1500- stream);
1509+ *quant_workspace_tensor, stream);
15011510}
0 commit comments