Commit e63e17d (1 parent: cf6f76b)

Formatting fixes

Signed-off-by: amitz-nv <[email protected]>

2 files changed: +25 −22 lines

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 10 additions & 10 deletions
@@ -305,10 +305,9 @@ class FusedMoeLauncher {
                                              (int32_t)tile_tokens_dim, this->use_shuffled_weight,
                                              this->weight_layout);
     } else {
-      moe_runner = std::make_unique<RunnerType>(this->mDtypeAct, this->mDtypeWeights,
-                                                args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim,
-                                                this->activation_type,
-                                                this->use_shuffled_weight, this->weight_layout);
+      moe_runner = std::make_unique<RunnerType>(
+          this->mDtypeAct, this->mDtypeWeights, args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim,
+          this->activation_type, this->use_shuffled_weight, this->weight_layout);
     }

     if (moe_tactic == -1) {
@@ -417,7 +416,8 @@ class Bf16MoeLauncher : public FusedMoeLauncher {
   void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args,
             int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
             int64_t weight_layout) {
-    constexpr ActivationType activation_type = ActivationType::Swiglu;  // not exposed in api for now
+    constexpr ActivationType activation_type =
+        ActivationType::Swiglu;  // not exposed in api for now

     // Do base class init and perform common checks
     FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type,
@@ -532,8 +532,8 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher {

   void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args,
             int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
-            int64_t weight_layout, bool use_routing_scales_on_input_param, ActivationType activation_type) {
-
+            int64_t weight_layout, bool use_routing_scales_on_input_param,
+            ActivationType activation_type) {
     this->use_routing_scales_on_input = use_routing_scales_on_input_param;

     auto dtype = hidden_states.dtype();
@@ -968,8 +968,7 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher {
     FusedMoeLauncher::init_common(
         std::move(args), tile_tokens_dim, routing_method_type,
         /*use_shuffled_weight=*/true,
-        static_cast<int64_t>(batchedGemm::gemm::MatrixLayout::BlockMajorK),
-        ActivationType::Swiglu);
+        static_cast<int64_t>(batchedGemm::gemm::MatrixLayout::BlockMajorK), ActivationType::Swiglu);
   }

   void check_routing() const override { FusedMoeLauncher::check_routing_common(); }
@@ -1763,7 +1762,8 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
         gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar,
         output2_scales_scalar, topk_ids, expert_weights);
     launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true,
-                   /*weight_layout=*/0, static_cast<ActivationType>(act_type), mDtypeAct, mDtypeWeights);
+                   /*weight_layout=*/0, static_cast<ActivationType>(act_type), mDtypeAct,
+                   mDtypeWeights);

     launchers_map[curr_tile_N] = std::move(launcher);
   }
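
For context on the last hunk: launchers are cached per tile size, so init runs only the first time a given curr_tile_N is seen. A minimal sketch of that lazy-cache pattern (the Launcher type and getLauncher helper are hypothetical stand-ins, not this file's actual classes):

#include <cstdint>
#include <map>
#include <memory>

struct Launcher {
  void init() {}  // init arguments elided; see trtllm_fp4_block_scale_moe above
};

// Hypothetical mirror of launchers_map in the hunk above: build a launcher
// lazily for each new tile size, then reuse it on later calls.
std::map<int64_t, std::unique_ptr<Launcher>> launchers_map;

Launcher& getLauncher(int64_t curr_tile_N) {
  auto it = launchers_map.find(curr_tile_N);
  if (it == launchers_map.end()) {
    auto launcher = std::make_unique<Launcher>();
    launcher->init();
    it = launchers_map.emplace(curr_tile_N, std::move(launcher)).first;
  }
  return *it->second;
}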

csrc/trtllm_fused_moe_runner.cu

Lines changed: 15 additions & 12 deletions
@@ -201,7 +201,8 @@ static inline ActType activationTypeToGatedActType(ActivationType actType) {
       return ActType::GeGlu;
     default:
       FLASHINFER_CHECK(false, "Unsupported gated activation type ",
-                       serializeActivationType(actType), " of enum ", static_cast<int64_t>(actType));
+                       serializeActivationType(actType), " of enum ",
+                       static_cast<int64_t>(actType));
   }
   return ActType::SwiGlu;
 }
@@ -214,7 +215,8 @@ static inline EltwiseActType activationTypeToEltwiseActType(ActivationType actType) {
       return EltwiseActType::None;
     default:
       FLASHINFER_CHECK(false, "Unsupported eltwise activation type ",
-                       serializeActivationType(actType), " of enum ", static_cast<int64_t>(actType));
+                       serializeActivationType(actType), " of enum ",
+                       static_cast<int64_t>(actType));
   }
   return EltwiseActType::None;
 }
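
The two hunks above only reflow FLASHINFER_CHECK's variadic message arguments. For readers unfamiliar with the pattern, here is a minimal sketch of such a check helper (a hypothetical stand-in; FlashInfer's actual macro may format and report errors differently):

#include <sstream>
#include <stdexcept>

// Hypothetical checkOrThrow: concatenates every message argument with
// operator<< and throws if the condition is false.
template <typename... Args>
void checkOrThrow(bool cond, Args&&... args) {
  if (!cond) {
    std::ostringstream oss;
    (oss << ... << args);  // C++17 fold expression over operator<<
    throw std::runtime_error(oss.str());
  }
}

// Usage mirroring the call sites above:
//   checkOrThrow(false, "Unsupported gated activation type ", name, " of enum ", code);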
@@ -224,8 +226,9 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions(
     ActivationType activationType, bool useShuffledMatrix,
     batchedGemm::gemm::MatrixLayout weightLayout) {
   int64_t actTypeInt = static_cast<int64_t>(activationType);
-  FLASHINFER_CHECK(0 <= actTypeInt && actTypeInt < static_cast<int64_t>(ActivationType::InvalidType),
-                   "Unknown activation type", serializeActivationType(activationType), "of enum", actTypeInt);
+  FLASHINFER_CHECK(
+      0 <= actTypeInt && actTypeInt < static_cast<int64_t>(ActivationType::InvalidType),
+      "Unknown activation type", serializeActivationType(activationType), "of enum", actTypeInt);
   bool isGatedAct = isGatedActivation(activationType);
   if (isGatedAct) {
     ActType actType = activationTypeToGatedActType(activationType);
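
The reflowed check above is a range guard: it rejects any integer that does not correspond to a valid ActivationType before the value is dispatched further. A standalone sketch of the same sentinel-bounded validation (the enumerator list is assumed for illustration; only the InvalidType end-of-range sentinel is taken from the code above):

#include <cstdint>
#include <stdexcept>
#include <string>

// Assumed enumerators; the real ActivationType may differ. The trailing
// InvalidType sentinel marks the end of the valid range.
enum class ActivationType : int64_t { Swiglu = 0, Geglu, Relu2, InvalidType };

ActivationType toActivationType(int64_t raw) {
  // Same guard as the FLASHINFER_CHECK above: accept only [0, InvalidType).
  if (raw < 0 || raw >= static_cast<int64_t>(ActivationType::InvalidType)) {
    throw std::invalid_argument("Unknown activation type of enum " + std::to_string(raw));
  }
  return static_cast<ActivationType>(raw);
}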
@@ -289,12 +292,13 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void*
   auto maxNumCtasInBatchDim =
       Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
   int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1);
-  mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens, numExperts,
-              maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale,
-              expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar,
-              ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx,
-              ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit,
-              ptrNumNonExitingCtas, bmm1Workspace, stream, device, configIndex, enable_pdl);
+  mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens,
+              numExperts, maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights,
+              weightsScale, expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar,
+              outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output,
+              outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx,
+              ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, stream, device,
+              configIndex, enable_pdl);
 }

 size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
@@ -477,8 +481,7 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
   activationData.inDqSfsPtr = workspace.gemm1_output_scale;
   activationData.outDqSfsPtr = workspace.activation_output_scale;
   activationData.innerDim =
-      args.intermediate_size *
-      (isGatedActivation(args.activation_type) ? 2 : 1);
+      args.intermediate_size * (isGatedActivation(args.activation_type) ? 2 : 1);
   activationData.topK = args.top_k;
   activationData.numTokens = args.num_tokens;
   activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx;
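
Both Runner::run and Runner::setOpsData scale the inner dimension by 2 when the activation is gated: gemm1 emits a linear half and a gate half per token, and the activation collapses them back down to intermediate_size. A minimal per-row sketch of SwiGlu under that layout (which half is gate versus linear is an assumption here; the kernels' actual memory layout may differ):

#include <cmath>
#include <vector>

// gemm1Out holds 2 * intermediateSize values per row: [gate half | linear half]
// (assumed order). SwiGlu returns SiLU(gate) * linear, of size intermediateSize.
std::vector<float> swigluRow(const std::vector<float>& gemm1Out, int intermediateSize) {
  std::vector<float> out(intermediateSize);
  for (int i = 0; i < intermediateSize; ++i) {
    float gate = gemm1Out[i];
    float linear = gemm1Out[i + intermediateSize];
    float silu = gate / (1.0f + std::exp(-gate));  // SiLU(x) = x * sigmoid(x)
    out[i] = silu * linear;
  }
  return out;
}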
