@@ -201,7 +201,8 @@ static inline ActType activationTypeToGatedActType(ActivationType actType) {
201201 return ActType::GeGlu;
202202 default:
203203 FLASHINFER_CHECK(false, "Unsupported gated activation type ",
204- serializeActivationType(actType), " of enum ", static_cast<int64_t>(actType));
204+ serializeActivationType(actType), " of enum ",
205+ static_cast<int64_t>(actType));
205206 }
206207 return ActType::SwiGlu;
207208}
@@ -214,7 +215,8 @@ static inline EltwiseActType activationTypeToEltwiseActType(ActivationType actTy
214215 return EltwiseActType::None;
215216 default:
216217 FLASHINFER_CHECK(false, "Unsupported eltwise activation type ",
217- serializeActivationType(actType), " of enum ", static_cast<int64_t>(actType));
218+ serializeActivationType(actType), " of enum ",
219+ static_cast<int64_t>(actType));
218220 }
219221 return EltwiseActType::None;
220222}
@@ -224,8 +226,9 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions(
224226 ActivationType activationType, bool useShuffledMatrix,
225227 batchedGemm::gemm::MatrixLayout weightLayout) {
226228 int64_t actTypeInt = static_cast<int64_t>(activationType);
227- FLASHINFER_CHECK(0 <= actTypeInt && actTypeInt < static_cast<int64_t>(ActivationType::InvalidType),
228- "Unknown activation type", serializeActivationType(activationType), "of enum", actTypeInt);
229+ FLASHINFER_CHECK(
230+ 0 <= actTypeInt && actTypeInt < static_cast<int64_t>(ActivationType::InvalidType),
231+ "Unknown activation type", serializeActivationType(activationType), "of enum", actTypeInt);
229232 bool isGatedAct = isGatedActivation(activationType);
230233 if (isGatedAct) {
231234 ActType actType = activationTypeToGatedActType(activationType);
@@ -289,12 +292,13 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void*
289292 auto maxNumCtasInBatchDim =
290293 Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
291294 int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1);
292- mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens, numExperts,
293- maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale,
294- expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar,
295- ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx,
296- ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit,
297- ptrNumNonExitingCtas, bmm1Workspace, stream, device, configIndex, enable_pdl);
295+ mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens,
296+ numExperts, maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights,
297+ weightsScale, expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar,
298+ outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output,
299+ outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx,
300+ ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, stream, device,
301+ configIndex, enable_pdl);
298302}
299303
300304size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
@@ -477,8 +481,7 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
477481 activationData.inDqSfsPtr = workspace.gemm1_output_scale;
478482 activationData.outDqSfsPtr = workspace.activation_output_scale;
479483 activationData.innerDim =
480- args.intermediate_size *
481- (isGatedActivation(args.activation_type) ? 2 : 1);
484+ args.intermediate_size * (isGatedActivation(args.activation_type) ? 2 : 1);
482485 activationData.topK = args.top_k;
483486 activationData.numTokens = args.num_tokens;
484487 activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx;
0 commit comments