@@ -436,8 +436,7 @@ __global__ static void row_col_rht_gemm_device(
436436 typename CLCPipeline::Params clc_pipeline_params;
437437 if (is_sched_warp) {
438438 clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer;
439- }
440- else {
439+ } else {
441440 clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer;
442441 }
443442 clc_pipeline_params.producer_blockid = 0 ;
@@ -549,9 +548,7 @@ __global__ static void row_col_rht_gemm_device(
549548 scheduler.update_work_tile_info ();
550549 } while (scheduler.is_valid ());
551550 mainloop_pipeline.producer_tail (mainloop_pipe_producer_state);
552- }
553-
554- else if (is_mma_warp) {
551+ } else if (is_mma_warp) {
555552 cutlass::arch::warpgroup_reg_dealloc<32 >();
556553 if constexpr (kEnableRHTColQuant ) {
557554 Tensor tCsA = make_tensor (make_smem_ptr (shared_storage.tensors .smem_A .data ()), sAlayout ); // (MMA,MMA_M,MMA_N,PIPE)
@@ -615,8 +612,7 @@ __global__ static void row_col_rht_gemm_device(
615612 accumulator_pipeline.producer_tail (accumulator_pipe_producer_state);
616613 tmem_allocator.free (tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
617614 }
618- }
619- else if (is_sched_warp) {
615+ } else if (is_sched_warp) {
620616 cutlass::arch::warpgroup_reg_dealloc<32 >();
621617 do {
622618 clc_throttle_pipeline.consumer_wait (clc_pipe_throttle_consumer_state);
@@ -627,8 +623,7 @@ __global__ static void row_col_rht_gemm_device(
627623 ++clc_pipeline_consumer_state;
628624 scheduler.update_work_tile_info ();
629625 } while (scheduler.is_valid ());
630- }
631- else if (is_epilogue_col_quant_warp) {
626+ } else if (is_epilogue_col_quant_warp) {
632627 cutlass::arch::warpgroup_reg_alloc<192 >();
633628 if constexpr (kEnableRHTColQuant ) {
634629 using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;
@@ -848,8 +843,7 @@ __global__ static void row_col_rht_gemm_device(
848843 scheduler.update_work_tile_info ();
849844 } while (scheduler.is_valid ());
850845 }
851- }
852- else if (is_epilogue_row_quant_warp) {
846+ } else if (is_epilogue_row_quant_warp) {
853847 cutlass::arch::warpgroup_reg_alloc<136 >();
854848 if constexpr (kEnableRowQuant ) {
855849 using S2RVectorType = uint128_t ;
@@ -1008,7 +1002,7 @@ __global__ static void row_col_rht_gemm_device(
10081002 } else {
10091003 cutlass::arch::warpgroup_reg_dealloc<32 >();
10101004 }
1011- }
1005+ } // NOLINT(readability/fn_size)
10121006
10131007
10141008// this function computes RHT-GEMM for
0 commit comments