  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "cublasScaledMMLut.h"
 #include "tensorrt_llm/common/cublasMMWrapper.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
 #include "tensorrt_llm/runtime/torchUtils.h"
 #include "tensorrt_llm/thop/thUtils.h"
 #include "userbuffersTensor.h"
-#include <array>
 #include <cublasLt.h>
 #include <torch/extension.h>
-#include <unordered_map>

 using torch::Tensor;

@@ -39,67 +38,7 @@ namespace

 using tensorrt_llm::common::check;
 using tensorrt_llm::common::CublasMMWrapper;
-
-struct hash_tuple
-{
-    size_t operator()(std::tuple<int, int, int> const& x) const
-    {
-        return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
-    }
-};
-
-// got from cublasTest matmultFind
-// {mp2, k, n}: {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}
-using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 8>, hash_tuple>;
-
-// bf16*bf16->fp32->bf16
-AlgoListType spark_bf16_algo_list = {
-    // GPT-OSS-20b
-    // -m201088 -n1 -algo21 -m_tile11 -m_stages20 -m_workmem0 -k2880
-    {{8, 2880, 201088}, {21, 11, 20, 1, 0, 0, 0, 0}},
-    // -m32 -n1 -algo14 -m_reduction2 -m_numsK10 -m_workmem1024 -k2880
-    {{8, 2880, 32}, {14, 0, 0, 10, 2, 0, 0, 0}},
-    // -m32 -n2048 -algo21 -m_tile11 -m_stages13 -m_reduction1 -m_numsK9 -m_workmem1024
-    // -k2880
-    {{2048, 2880, 32}, {21, 11, 13, 9, 1, 0, 0, 0}},
-    // -m32 -n2175 -algo21 -m_tile11 -m_stages19 -m_reduction1 -m_numsK11
-    // -m_workmem1024 -k2880
-    {{4096, 2880, 32}, {21, 11, 19, 11, 1, 0, 0, 0}},
-    // -m5120 -n1 -algo23 -m_tile11 -m_stages8 -m_reduction1 -m_numsK2
-    // -m_workmem1024 -k2880
-    {{8, 2880, 5120}, {23, 11, 8, 2, 1, 0, 0, 0}},
-    // -m5120 -n2048 -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k2880
-    {{2048, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
-    // -m5120 -n2175 -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k2880
-    {{4096, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
-    // -m2880 -n1 -algo23 -m_tile11 -m_stages14 -m_reduction1 -m_numsK24 -m_workmem1024 -k4096
-    {{8, 4096, 2880}, {23, 11, 14, 24, 1, 0, 0, 0}},
-    // -m2880 -n2048 -ldc2880 -Poutt -ldd2880 -Ps -Pscales -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k4096
-    {{2048, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
-    // -m2880 -n2175 -ldc2880 -Poutt -ldd2880 -Ps -Pscales -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k4096
-    {{4096, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
-};
-
-// bf16*bf16->fp32->bf16
-AlgoListType bf16_algo_list = {
-    // Deepseek v3/R1 router gemm
-    // [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom3 -m_mma0 -m_cga2 -m_scheduling1]
-    {{8, 7168, 256}, {66, 10, 35, 1, 0, 0, 3, 2}},
-    {{512, 7168, 256}, {66, 48, 35, 1, 0, 0, 0, 2}},
-    {{1024, 7168, 256}, {66, 13, 35, 1, 0, 0, 1, 3}},
-};
-
-// fp8*fp8->fp32->fp16
-AlgoListType fp8_algo_list = {
-    // Llama-3.1-70B
-    // [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
-    {{8, 8192, 8192}, {66, 393, 36, 1, 0, 0, 5, 2}},
-    // [-algo66 -m_tile10 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga2 -m_scheduling1]
-    {{8, 8192, 57344}, {66, 10, 36, 1, 0, 0, 1, 2}},
-    // Llama-3.3-70B TP4 (this is the default algo on B200. Here we aim to use the same algo on GB200.)
-    // [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga4 -m_scheduling1]
-    {{8, 8192, 14336}, {66, 393, 36, 1, 0, 1, 1, 4}},
-};
+using cublas_lut::AlgoListType;

 void set_algo_attr(cublasLtMatmulAlgo_t& algo, std::array<int, 8> const& attr_list)
 {
@@ -127,29 +66,31 @@ bool find_special_algo(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapper
     cudaDataType_t bType, cudaDataType_t outType)
 {
     int32_t mp2 = std::max(nextPowerOfTwo(m), 8);
-    AlgoListType algo_list;
+    AlgoListType const* algo_list = nullptr;
     if ((aType == CUDA_R_16BF || aType == CUDA_R_16F) && (outType == aType || outType == CUDA_R_32F)
         && compType == CUBLAS_COMPUTE_32F)
     {
         // TODO: remove this after cublas fix the heuristic for Spark
-        algo_list = tensorrt_llm::common::getSMVersion(/*queryRealSmArch=*/ true) == 121 ? spark_bf16_algo_list
-                                                                                          : bf16_algo_list;
+        algo_list = tensorrt_llm::common::getSMVersion(/*queryRealSmArch=*/ true) == 121
+            ? &cublas_lut::spark_bf16_algo_list
+            : &cublas_lut::bf16_algo_list;
     }
     else if (aType == CUDA_R_8F_E4M3 && compType == CUBLAS_COMPUTE_32F)
     {
-        algo_list = fp8_algo_list;
+        algo_list = &cublas_lut::fp8_algo_list;
     }
     else
     {
         TLLM_LOG_DEBUG(
             "No special cublasLt algo found for aType=%d, outType=%d, compType=%d\n", aType, outType, compType);
         return false;
     }
-    if (auto algo_iter = algo_list.find({mp2, k, n}); algo_iter != algo_list.end())
+    if (auto algo_iter = algo_list->find({mp2, k, n}); algo_iter != algo_list->end())
     {
         int const algoID = algo_iter->second[0];
         check_cuda_error(cublasLtMatmulAlgoInit(
             cublasWrapper->getCublasLtHandle(), compType, scaleType, aType, bType, outType, outType, algoID, &algo));
+        TLLM_LOG_DEBUG("Found special cublasLt algo for m=%d, k=%d, n=%d\n", m, k, n);
         set_algo_attr(algo, algo_iter->second);
     }
     else
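For reference, below is a minimal sketch of what the new cublasScaledMMLut.h header presumably declares, reconstructed from the definitions this diff removes from the .cpp file. The cublas_lut namespace name is taken from the new qualified references above; the header guard, the extern-declaration layout, and the constness of the tables are assumptions, not the actual header contents.

// Sketch only (assumed layout): the removed LUT definitions re-homed under the
// cublas_lut namespace referenced by the new code above.
#pragma once

#include <array>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <unordered_map>

namespace cublas_lut
{

struct hash_tuple
{
    size_t operator()(std::tuple<int, int, int> const& x) const
    {
        // Same cheap XOR combine as the removed in-file definition.
        return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
    }
};

// Key {mp2, k, n} -> value {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}.
// mp2 is m rounded up to the next power of two and clamped to >= 8, matching the
// lookup in find_special_algo (e.g. m=1 -> mp2=8, m=2175 -> mp2=4096).
using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 8>, hash_tuple>;

// Tables tuned via cublasTest matmultFind; entries identical to the lists removed above.
// Assumed to be defined in an accompanying .cpp, hence declared extern here.
extern AlgoListType const spark_bf16_algo_list; // bf16*bf16->fp32->bf16, SM 121 (Spark)
extern AlgoListType const bf16_algo_list;       // bf16*bf16->fp32->bf16
extern AlgoListType const fp8_algo_list;        // fp8*fp8->fp32->fp16

} // namespace cublas_lut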