Commit fc42825

lint and format
Signed-off-by: Zhongbo Zhu <[email protected]>
1 parent 7ed5d9e

File tree: 2 files changed, +8 -19 lines

benchmarks/linear/benchmark_linear.py

Lines changed: 2 additions & 7 deletions
@@ -140,10 +140,7 @@ def benchmark_linear(
     label = f"{recipe_name}_{'linear'}"
     torch.cuda.nvtx.range_push(label)
     timing = benchmark.Timer(
-        stmt=(
-            "run_linear_multiple_steps(layer, x, mode, gradient, num_microbatches,"
-            " recipe)"
-        ),
+        stmt="run_linear_multiple_steps(layer, x, mode, gradient, num_microbatches, recipe)",
         globals={
             "run_linear_multiple_steps": run_linear_multiple_steps,
             "layer": layer,
@@ -161,9 +158,7 @@ def benchmark_linear(
     return timing_ms


-def run_benchmark_linear(
-    mkns, recipe_name, use_bias, fwd_only=False
-):
+def run_benchmark_linear(mkns, recipe_name, use_bias, fwd_only=False):
     data = []
     assert not use_bias, "Bias is not supported in this benchmark script"

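For context on the pattern being reformatted above, here is a minimal, self-contained sketch of how torch.utils.benchmark.Timer is typically driven with a string stmt and a globals dict. The nn.Linear layer and input tensor below are hypothetical stand-ins; the actual benchmark script instead passes run_linear_multiple_steps and its arguments through globals, as shown in the diff.

import torch
import torch.utils.benchmark as benchmark

# Hypothetical stand-in layer and input; the real script builds a
# Transformer Engine Linear layer and a quantization recipe instead.
layer = torch.nn.Linear(1024, 1024)
x = torch.randn(32, 1024)

timing = benchmark.Timer(
    stmt="layer(x)",                   # single-line stmt string, mirroring the reformatted call
    globals={"layer": layer, "x": x},  # names referenced inside stmt
)
# blocked_autorange() picks the number of runs automatically and returns a Measurement.
print(f"{timing.blocked_autorange().median * 1e3:.3f} ms")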
transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu

Lines changed: 6 additions & 12 deletions
@@ -436,8 +436,7 @@ __global__ static void row_col_rht_gemm_device(
   typename CLCPipeline::Params clc_pipeline_params;
   if (is_sched_warp) {
     clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer;
-  }
-  else {
+  } else {
     clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer;
   }
   clc_pipeline_params.producer_blockid = 0;
@@ -549,9 +548,7 @@ __global__ static void row_col_rht_gemm_device(
       scheduler.update_work_tile_info();
     } while (scheduler.is_valid());
     mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
-  }
-
-  else if (is_mma_warp) {
+  } else if (is_mma_warp) {
     cutlass::arch::warpgroup_reg_dealloc<32>();
     if constexpr (kEnableRHTColQuant) {
       Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE)
@@ -615,8 +612,7 @@ __global__ static void row_col_rht_gemm_device(
       accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
       tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
     }
-  }
-  else if(is_sched_warp) {
+  } else if(is_sched_warp) {
     cutlass::arch::warpgroup_reg_dealloc<32>();
     do {
       clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state);
@@ -627,8 +623,7 @@ __global__ static void row_col_rht_gemm_device(
       ++clc_pipeline_consumer_state;
       scheduler.update_work_tile_info();
     } while (scheduler.is_valid());
-  }
-  else if (is_epilogue_col_quant_warp) {
+  } else if (is_epilogue_col_quant_warp) {
     cutlass::arch::warpgroup_reg_alloc<192>();
     if constexpr (kEnableRHTColQuant) {
       using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;
@@ -848,8 +843,7 @@ __global__ static void row_col_rht_gemm_device(
         scheduler.update_work_tile_info();
       } while (scheduler.is_valid());
     }
-  }
-  else if (is_epilogue_row_quant_warp) {
+  } else if (is_epilogue_row_quant_warp) {
     cutlass::arch::warpgroup_reg_alloc<136>();
     if constexpr (kEnableRowQuant) {
       using S2RVectorType = uint128_t;
@@ -1008,7 +1002,7 @@ __global__ static void row_col_rht_gemm_device(
   } else {
     cutlass::arch::warpgroup_reg_dealloc<32>();
   }
-}
+} // NOLINT(readability/fn_size)


 // this function computes RHT-GEMM for

0 commit comments
