@@ -36,6 +36,25 @@ using namespace batchedGemm::trtllm::gen;
3636
3737static BatchedGemmInterface::ModuleCache globalTrtllmGenBatchedGemmModuleCache;
3838
39+ constexpr bool isSMCompatible (int gpuSM, SmVersion kernelSM)
40+ {
41+ if (gpuSM == 103 )
42+ {
43+ return kernelSM == SmVersion::Sm100f || kernelSM == SmVersion::Sm103a;
44+ }
45+ else if (gpuSM == 100 )
46+ {
47+ return kernelSM == SmVersion::Sm100f || kernelSM == SmVersion::Sm100a;
48+ }
49+ else if (gpuSM == 90 )
50+ {
51+ return kernelSM == SmVersion::Sm90a;
52+ }
53+
54+ TLLM_THROW (" Unexpected gpuSM %d" , gpuSM);
55+ return false ;
56+ }
57+
3958std::vector<int64_t > prioritizePredefinedConfigs (int m, int n, int k, std::vector<int64_t > const & sortedIndices,
4059 batchedGemm::batchedGemm::BatchedGemmConfig const * configs)
4160{
@@ -98,6 +117,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
98117
99118 mPassingConfigIndices .clear ();
100119
120+ int gpuSM = tensorrt_llm::common::getSMVersion ();
101121 for (size_t i = 0 ; i < bmm.getNumBatchedGemmConfigs (); ++i)
102122 {
103123 auto const options = configs[i].mOptions ;
@@ -108,7 +128,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
108128 && options.mTransposeMmaOutput == mOptions .transposeMmaOutput
109129 && (!doesRouteImplUseNoRoute (options.mRouteImpl )) == mOptions .routeAct
110130 && options.mFusedAct == mOptions .fusedAct && options.mIsStaticBatch == mOptions .staticBatch
111- && tileSize == mOptions .tileSize )
131+ && tileSize == mOptions .tileSize && isSMCompatible (gpuSM, configs[i]. mSm ) )
112132 {
113133 auto sm = configs[i].mSm ;
114134 if (sm != SmVersion::Sm100f)
0 commit comments