Commit 98d72c7

[None][feat] spark cublas LUT table for llama-8b-bf16 perf (NVIDIA#9811)
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
1 parent e4e0986 commit 98d72c7

File tree

4 files changed: +122 -69 lines changed

cpp/tensorrt_llm/thop/cublasScaledMM.cpp

Lines changed: 9 additions & 68 deletions
@@ -14,6 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "cublasScaledMMLut.h"
 #include "tensorrt_llm/common/cublasMMWrapper.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
@@ -22,10 +23,8 @@
 #include "tensorrt_llm/runtime/torchUtils.h"
 #include "tensorrt_llm/thop/thUtils.h"
 #include "userbuffersTensor.h"
-#include <array>
 #include <cublasLt.h>
 #include <torch/extension.h>
-#include <unordered_map>
 
 using torch::Tensor;
 
@@ -39,67 +38,7 @@ namespace
 
 using tensorrt_llm::common::check;
 using tensorrt_llm::common::CublasMMWrapper;
-
-struct hash_tuple
-{
-    size_t operator()(std::tuple<int, int, int> const& x) const
-    {
-        return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
-    }
-};
-
-// got from cublasTest matmultFind
-// {mp2, k, n}: {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}
-using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 8>, hash_tuple>;
-
-// bf16*bf16->fp32->bf16
-AlgoListType spark_bf16_algo_list = {
-    // GPT-OSS-20b
-    //-m201088 -n1 -algo21 -m_tile11 -m_stages20 -m_workmem0 -k2880
-    {{8, 2880, 201088}, {21, 11, 20, 1, 0, 0, 0, 0}},
-    //-m32 -n1 -algo14 -m_reduction2 -m_numsK10 -m_workmem1024 -k2880
-    {{8, 2880, 32}, {14, 0, 0, 10, 2, 0, 0, 0}},
-    //-m32 -n2048 -algo21 -m_tile11 -m_stages13 -m_reduction1 -m_numsK9 -m_workmem1024
-    //-k2880
-    {{2048, 2880, 32}, {21, 11, 13, 9, 1, 0, 0, 0}},
-    //-m32 -n2175 -algo21 -m_tile11 -m_stages19 -m_reduction1 -m_numsK11
-    //-m_workmem1024 -k2880
-    {{4096, 2880, 32}, {21, 11, 19, 11, 1, 0, 0, 0}},
-    //-m5120 -n1 -algo23 -m_tile11 -m_stages8 -m_reduction1 -m_numsK2
-    //-m_workmem1024 -k2880
-    {{8, 2880, 5120}, {23, 11, 8, 2, 1, 0, 0, 0}},
-    //-m5120 -n2048 -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k2880
-    {{2048, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
-    //-m5120 -n2175 -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k2880
-    {{4096, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
-    //-m2880 -n1 -algo23 -m_tile11 -m_stages14 -m_reduction1 -m_numsK24 -m_workmem1024 -k4096
-    {{8, 4096, 2880}, {23, 11, 14, 24, 1, 0, 0, 0}},
-    //-m2880 -n2048 -ldc2880 -Poutt -ldd2880 -Ps -Pscales -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k4096
-    {{2048, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
-    //-m2880 -n2175 -ldc2880 -Poutt -ldd2880 -Ps -Pscales -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k4096
-    {{4096, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
-};
-
-// bf16*bf16->fp32->bf16
-AlgoListType bf16_algo_list = {
-    // Deepseek v3/R1 router gemm
-    // [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom3 -m_mma0 -m_cga2 -m_scheduling1]
-    {{8, 7168, 256}, {66, 10, 35, 1, 0, 0, 3, 2}},
-    {{512, 7168, 256}, {66, 48, 35, 1, 0, 0, 0, 2}},
-    {{1024, 7168, 256}, {66, 13, 35, 1, 0, 0, 1, 3}},
-};
-
-// fp8*fp8->fp32->fp16
-AlgoListType fp8_algo_list = {
-    // Llama-3.1-70B
-    // [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
-    {{8, 8192, 8192}, {66, 393, 36, 1, 0, 0, 5, 2}},
-    // [-algo66 -m_tile10 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga2 -m_scheduling1]
-    {{8, 8192, 57344}, {66, 10, 36, 1, 0, 0, 1, 2}},
-    // Llama-3.3-70B TP4 (this is the default algo on B200. Here we aim to use the same algo on GB200.)
-    // [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga4 -m_scheduling1]
-    {{8, 8192, 14336}, {66, 393, 36, 1, 0, 1, 1, 4}},
-};
+using cublas_lut::AlgoListType;
 
 void set_algo_attr(cublasLtMatmulAlgo_t& algo, std::array<int, 8> const& attr_list)
 {
@@ -127,29 +66,31 @@ bool find_special_algo(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapp
     cudaDataType_t bType, cudaDataType_t outType)
 {
     int32_t mp2 = std::max(nextPowerOfTwo(m), 8);
-    AlgoListType algo_list;
+    AlgoListType const* algo_list = nullptr;
     if ((aType == CUDA_R_16BF || aType == CUDA_R_16F) && (outType == aType || outType == CUDA_R_32F)
         && compType == CUBLAS_COMPUTE_32F)
     {
         // TODO: remove this after cublas fix the heuristic for Spark
-        algo_list = tensorrt_llm::common::getSMVersion(/*queryRealSmArch=*/true) == 121 ? spark_bf16_algo_list
-                                                                                        : bf16_algo_list;
+        algo_list = tensorrt_llm::common::getSMVersion(/*queryRealSmArch=*/true) == 121
+            ? &cublas_lut::spark_bf16_algo_list
+            : &cublas_lut::bf16_algo_list;
     }
     else if (aType == CUDA_R_8F_E4M3 && compType == CUBLAS_COMPUTE_32F)
     {
-        algo_list = fp8_algo_list;
+        algo_list = &cublas_lut::fp8_algo_list;
     }
     else
    {
        TLLM_LOG_DEBUG(
            "No special cublasLt algo found for aType=%d, outType=%d, compType=%d\n", aType, outType, compType);
        return false;
    }
-    if (auto algo_iter = algo_list.find({mp2, k, n}); algo_iter != algo_list.end())
+    if (auto algo_iter = algo_list->find({mp2, k, n}); algo_iter != algo_list->end())
     {
         int const algoID = algo_iter->second[0];
         check_cuda_error(cublasLtMatmulAlgoInit(
             cublasWrapper->getCublasLtHandle(), compType, scaleType, aType, bType, outType, outType, algoID, &algo));
+        TLLM_LOG_DEBUG("Found special cublasLt algo for m=%d, k=%d, n=%d\n", m, k, n);
         set_algo_attr(algo, algo_iter->second);
     }
     else
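
For readers skimming the diff: the lookup that find_special_algo performs is easy to lose among the +/- lines. Below is a minimal, self-contained sketch of the same pattern — round m up to a power-of-two bucket (minimum 8), pick a const table by architecture, and fall back to the cuBLASLt heuristic when the {mp2, k, n} key is absent. All names here (sketch::findSpecialAlgo, spark_bf16, default_bf16, nextPowerOfTwo) are illustrative placeholders, not the TensorRT-LLM API.

// Illustrative sketch only; not the TensorRT-LLM implementation.
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <tuple>
#include <unordered_map>

namespace sketch
{

struct HashTuple
{
    size_t operator()(std::tuple<int32_t, int32_t, int32_t> const& x) const
    {
        // Same weak XOR combiner as the header; fine for a handful of fixed keys.
        return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
    }
};

// {mp2, k, n} -> {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}
using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 8>, HashTuple>;

inline const AlgoListType spark_bf16 = {
    // llama-8b bf16 decode GEMM (4096 x 4096), taken from the table added in this commit.
    {{8, 4096, 4096}, {67, 6, 35, 1, 0, 0, 130, 2}},
};
inline const AlgoListType default_bf16 = {};

int32_t nextPowerOfTwo(int32_t v)
{
    int32_t p = 1;
    while (p < v)
    {
        p <<= 1;
    }
    return p;
}

// Returns true and fills `attrs` when a tuned config exists for this shape.
bool findSpecialAlgo(int m, int k, int n, bool isSpark, std::array<int, 8>& attrs)
{
    int32_t const mp2 = std::max(nextPowerOfTwo(m), 8);
    AlgoListType const* table = isSpark ? &spark_bf16 : &default_bf16;
    auto it = table->find({mp2, k, n});
    if (it == table->end())
    {
        return false; // fall back to the cuBLASLt heuristic
    }
    attrs = it->second;
    return true;
}

} // namespace sketch

int main()
{
    std::array<int, 8> attrs{};
    // Decode step: m is the (small) token count, rounded up to the power-of-two bucket 8.
    bool hit = sketch::findSpecialAlgo(/*m=*/1, /*k=*/4096, /*n=*/4096, /*isSpark=*/true, attrs);
    std::printf("hit=%d algo=%d tile=%d stages=%d\n", hit, attrs[0], attrs[1], attrs[2]);
    return 0;
}

Running this with m=1 shows how any decode batch of 1-8 tokens lands in the same {8, 4096, 4096} row that the new spark_bf16_algo_list adds for the Llama-8B GEMM shapes.
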
cpp/tensorrt_llm/thop/cublasScaledMMLut.h

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <tuple>
+#include <unordered_map>
+
+namespace torch_ext
+{
+namespace cublas_lut
+{
+
+struct HashTuple
+{
+    size_t operator()(std::tuple<int32_t, int32_t, int32_t> const& x) const
+    {
+        return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
+    }
+};
+
+// {mp2, k, n}: {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}
+using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 8>, HashTuple>;
+
+inline const AlgoListType spark_bf16_algo_list = {
+    // llama 8b instruct fp16 decode
+    // [-algo67 -m_tile6 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom130 -m_mma0 -m_cga2 -m_scheduling1]
+    {{8, 4096, 4096}, {67, 6, 35, 1, 0, 0, 130, 2}},
+    // [-algo67 -m_tile393 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom142 -m_mma0 -m_cga2 -m_scheduling1]
+    {{8, 4096, 6144}, {67, 393, 35, 1, 0, 0, 142, 2}},
+    // [-algo67 -m_tile393 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom142 -m_mma0 -m_cga2 -m_scheduling1]
+    {{8, 4096, 128256}, {67, 393, 35, 1, 0, 0, 142, 2}},
+
+    // gpt-oss mxfp4-fp16 decode
+    // [-algo67 -m_tile393 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom142 -m_mma0 -m_cga2 -m_scheduling1]
+    {{8, 2880, 201088}, {67, 393, 35, 1, 0, 0, 142, 2}},
+    // [-algo14 -m_tile0 -m_stages35 -m_numsK10 -m_reduction2 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
+    {{8, 2880, 32}, {14, 0, 0, 10, 2, 0, 0, 0}},
+    // [-algo21 -m_tile11 -m_stages13 -m_numsK9 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
+    //-k2880
+    {{2048, 2880, 32}, {21, 11, 13, 9, 1, 0, 0, 0}},
+    // [-algo21 -m_tile11 -m_stages19 -m_numsK11 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
+    //-m_workmem1024 -k2880
+    {{4096, 2880, 32}, {21, 11, 19, 11, 1, 0, 0, 0}},
+    // [-algo23 -m_tile11 -m_stages8 -m_numsK2 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
+    //-m_workmem1024 -k2880
+    {{8, 2880, 5120}, {23, 11, 8, 2, 1, 0, 0, 0}},
+    // [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
+    {{2048, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
+    // [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
+    {{4096, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
+    // [-algo23 -m_tile11 -m_stages14 -m_numsK24 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
+    {{8, 4096, 2880}, {23, 11, 14, 24, 1, 0, 0, 0}},
+    // [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
+    {{2048, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
+    // [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
+    {{4096, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
+
+};
+
+// bf16*bf16->fp32->bf16
+inline const AlgoListType bf16_algo_list = {
+    // Deepseek v3/R1 router gemm
+    // [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom3 -m_mma0 -m_cga2 -m_scheduling1]
+    {{8, 7168, 256}, {66, 10, 35, 1, 0, 0, 3, 2}},
+    {{512, 7168, 256}, {66, 48, 35, 1, 0, 0, 0, 2}},
+    {{1024, 7168, 256}, {66, 13, 35, 1, 0, 0, 1, 3}},
+};
+
+// fp8*fp8->fp32->fp16
+inline const AlgoListType fp8_algo_list = {
+    // Llama-3.1-70B
+    // [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
+    {{8, 8192, 8192}, {66, 393, 36, 1, 0, 0, 5, 2}},
+    // [-algo66 -m_tile10 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga2 -m_scheduling1]
+    {{8, 8192, 57344}, {66, 10, 36, 1, 0, 0, 1, 2}},
+    // Llama-3.3-70B TP4 (this is the default algo on B200. Here we aim to use the same algo on GB200.)
+    // [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga4 -m_scheduling1]
+    {{8, 8192, 14336}, {66, 393, 36, 1, 0, 1, 1, 4}},
+};
+
+} // namespace cublas_lut
+} // namespace torch_ext
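
The header's 8-int records follow the comment {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}: index 0 is the algorithm id passed to cublasLtMatmulAlgoInit (as in the .cpp diff above), and the remaining fields are per-algo config attributes. Below is a hedged sketch of how such a record could be applied through the public cuBLASLt API. It is an assumption for illustration, not this file's actual set_algo_attr; applyAlgoRecord is a made-up name, and the CGA/cluster field (index 7) is omitted rather than guessing at its attribute name.

// Hedged sketch (assumption, not the repository's set_algo_attr).
#include <array>
#include <cstdint>
#include <cublasLt.h>

inline cublasStatus_t applyAlgoRecord(cublasLtMatmulAlgo_t& algo, std::array<int, 8> const& rec)
{
    // All attributes set below are 32-bit values in cublasLt.h.
    auto set = [&algo](cublasLtMatmulAlgoConfigAttributes_t attr, int32_t value)
    { return cublasLtMatmulAlgoConfigSetAttribute(&algo, attr, &value, sizeof(value)); };

    cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
    // rec[0] is the algo id, already consumed by cublasLtMatmulAlgoInit().
    if (status == CUBLAS_STATUS_SUCCESS) status = set(CUBLASLT_ALGO_CONFIG_TILE_ID, rec[1]);          // m_tile
    if (status == CUBLAS_STATUS_SUCCESS) status = set(CUBLASLT_ALGO_CONFIG_STAGES_ID, rec[2]);        // m_stages
    if (status == CUBLAS_STATUS_SUCCESS) status = set(CUBLASLT_ALGO_CONFIG_SPLITK_NUM, rec[3]);       // m_numsK
    if (status == CUBLAS_STATUS_SUCCESS) status = set(CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, rec[4]); // m_reduction
    if (status == CUBLAS_STATUS_SUCCESS) status = set(CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, rec[5]);    // m_swizzle
    if (status == CUBLAS_STATUS_SUCCESS) status = set(CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, rec[6]);    // m_custom
    // rec[7] (m_cga / cluster shape) is intentionally not set in this sketch.
    return status;
}

Read this way, the new Llama-8B decode entry {67, 6, 35, 1, 0, 0, 130, 2} requests algo 67 with tile id 6, 35 stages, no split-K (numsK 1), and custom option 130.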

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 11 additions & 1 deletion
@@ -230,6 +230,7 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
+        use_custom_cublas_mm: bool = False,
     ):
         config = model_config.pretrained_config
         super().__init__(
@@ -245,6 +246,7 @@ def __init__(
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            use_custom_cublas_mm=use_custom_cublas_mm,
         )
 
 
@@ -618,6 +620,7 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: int,
+        use_custom_cublas_mm: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
@@ -634,6 +637,7 @@ def __init__(
         self.self_attn = LlamaAttention(
             model_config,
             layer_idx=layer_idx,
+            use_custom_cublas_mm=use_custom_cublas_mm,
         )
 
         self.mlp = GatedMLP(
@@ -643,6 +647,7 @@ def __init__(
             dtype=config.torch_dtype,
             config=model_config,
             layer_idx=layer_idx,
+            use_custom_cublas_mm=use_custom_cublas_mm,
         )
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -889,6 +894,8 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
         config = self.model_config.pretrained_config
         self.num_hidden_layers = config.num_hidden_layers
 
+        self.use_custom_cublas_mm = get_sm_version() == 121
+
         vocab_size = config.vocab_size
         # TODO smor- we load manually only if there is a single lora dir, need to come up with a better solution
         self.has_custom_embed_tokens = False
@@ -909,6 +916,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
                 vocab_size,
                 config.hidden_size,
                 dtype=config.torch_dtype,
+                use_custom_cublas_mm=self.use_custom_cublas_mm,
             )
         else:
             self.embed_tokens = Embedding(
@@ -918,6 +926,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
                 mapping=model_config.mapping,
                 tensor_parallel_mode=TensorParallelMode.COLUMN,
                 gather_output=True,
+                use_custom_cublas_mm=self.use_custom_cublas_mm,
             )
 
         if self.has_custom_embed_tokens:
@@ -932,7 +941,8 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
             self.embed_tokens.weight.data.copy_(x)
 
         self.layers = nn.ModuleList([
-            LlamaDecoderLayer(model_config, layer_idx)
+            LlamaDecoderLayer(model_config, layer_idx,
+                              self.use_custom_cublas_mm)
             for layer_idx in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(hidden_size=config.hidden_size,

tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,7 @@ def __init__(
         layer_idx: Optional[int] = None,
         use_cute_dsl_blockscaling_mm: bool = False,
         disable_deep_gemm: bool = False,
+        use_custom_cublas_mm: bool = False,
     ):
 
         super().__init__()
@@ -83,6 +84,7 @@ def __init__(
             use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
             disable_deep_gemm=disable_deep_gemm,
             fused_weight_shard_indices_mapping=gateup_shard_indices_mapping,
+            use_custom_cublas_mm=use_custom_cublas_mm,
         )
 
         self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
@@ -103,6 +105,7 @@ def __init__(
             force_dynamic_quantization=config.force_dynamic_quantization,
             use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
             disable_deep_gemm=disable_deep_gemm,
+            use_custom_cublas_mm=use_custom_cublas_mm,
         )
 
         # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
