Commit e49c70f

[None][feat] Support Mistral Large3 LLM part (NVIDIA#9820)
Signed-off-by: bhsueh <[email protected]>
1 parent 98d72c7

File tree

20 files changed: +946 -69 lines changed
examples/llm-api/quickstart_advanced.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -23,6 +23,11 @@ def add_llm_args(parser):
                         type=str,
                         nargs="+",
                         help="A single or a list of text prompts.")
+    parser.add_argument('--checkpoint_format',
+                        type=str,
+                        default=None,
+                        choices=["HF", "mistral"],
+                        help="Model checkpoint format.")
     # Build config
     parser.add_argument("--max_seq_len",
                         type=int,
@@ -237,6 +242,7 @@ def setup_llm(args, **kwargs):
     llm = LLM(
         model=args.model_dir,
         backend='pytorch',
+        checkpoint_format=args.checkpoint_format,
         disable_overlap_scheduler=args.disable_overlap_scheduler,
         kv_cache_config=kv_cache_config,
         attn_backend=args.attention_backend,
```
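The new flag is forwarded straight into the `LLM` constructor, so the checkpoint format can also be selected programmatically. A minimal sketch, assuming a local Mistral-format checkpoint; the model path and parallel sizes below are placeholders:

```python
from tensorrt_llm import LLM  # PyTorch-backend LLM API

# Minimal sketch: load a Mistral-format checkpoint through the new argument.
# The path and parallelism settings are placeholders, not values from this commit.
llm = LLM(
    model="<mistral_large_3_model_path>",
    backend="pytorch",
    checkpoint_format="mistral",
    tensor_parallel_size=4,
    moe_expert_parallel_size=4,
)
print(llm.generate(["The capital of France is"])[0].outputs[0].text)
```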
Lines changed: 53 additions & 0 deletions (new file)

# Mistral Large V3

* Set up the model path

```bash
export mistral_large_3_model_path=<mistral_large_3_model_path>
```

## LLM-only run

* Run Mistral Large V3 with `quickstart_advanced.py`

```bash
mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py \
    --model_dir ${mistral_large_3_model_path} \
    --tp_size 4 \
    --moe_ep_size 4 \
    --max_tokens 100 \
    --checkpoint_format mistral \
    --moe_backend TRTLLM
```

* Launch `trtllm-serve` and send a request

```bash
echo "
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: false
kv_cache_config:
  enable_block_reuse: true
checkpoint_format: mistral
" > serve.yml

mpirun -n 1 --allow-run-as-root --oversubscribe python3 -m tensorrt_llm.commands.serve serve \
    ${mistral_large_3_model_path} \
    --host localhost --port 8001 --backend pytorch \
    --extra_llm_api_options serve.yml \
    --tokenizer ${mistral_large_3_model_path} \
    2>&1 | tee serve_debug.log &

curl http://localhost:8001/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "${mistral_large_3_model_path}",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "top_k": 16
    }'

# The result will look like:
{"id":"cmpl-7e342c1d722d4226a1bf3ed35d762c35","object":"text_completion","created":1764061351,"model":"${mistral_large_3_model_path}","choices":[{"index":0,"text":"The capital of France is **Paris**.\n\nParis is the largest city in France and","token_ids":null,"logprobs":null,"context_logits":null,"finish_reason":"length","stop_reason":null,"disaggregated_params":null,"avg_decoded_tokens_per_iter":1.0}],"usage":{"prompt_tokens":7,"total_tokens":23,"completion_tokens":16,"prompt_tokens_details":{"cached_tokens":1}},"prompt_token_ids":null}
```
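For scripted use, the same completion request can be sent from Python. A minimal sketch using the `requests` package; the URL, port, and model path simply mirror the curl example above:

```python
import requests

# Minimal sketch: query the completions endpoint started by trtllm-serve above.
# The URL and model path mirror the curl example; adjust them to your setup.
resp = requests.post(
    "http://localhost:8001/v1/completions",
    json={
        "model": "<mistral_large_3_model_path>",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "top_k": 16,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["text"])
```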

requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -75,3 +75,4 @@ numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
 partial_json_parser
 apache-tvm-ffi==0.1.4 # used for reduce nvidia-cutlass-dsl host overhead
 torch-c-dlpack-ext==0.1.3 # used for reduce nvidia-cutlass-dsl host overhead, optional package for improved torch tensor calling perf
+mistral-common==1.8.6
```
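`mistral-common` is Mistral's tokenization and request-formatting library, presumably pulled in here for handling the tokenizer shipped with mistral-format checkpoints. A standalone usage sketch, where the `tekken.json` filename and the exact API calls are assumptions about the library rather than anything this commit adds:

```python
# Standalone sketch of mistral-common usage; not part of this commit.
# The tokenizer filename ("tekken.json") is an assumption about the checkpoint layout.
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

tokenizer = MistralTokenizer.from_file("<mistral_large_3_model_path>/tekken.json")
tokenized = tokenizer.encode_chat_completion(
    ChatCompletionRequest(messages=[UserMessage(content="The capital of France is")])
)
print(len(tokenized.tokens))  # number of prompt tokens produced
```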

tensorrt_llm/_torch/models/checkpoints/__init__.py

Lines changed: 24 additions & 5 deletions
```diff
@@ -12,11 +12,30 @@
 from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper
 from .hf.weight_loader import HfWeightLoader
 from .hf.weight_mapper import HfWeightMapper
+from .mistral.checkpoint_loader import (MistralCheckpointLoader,
+                                        MistralLarge3CheckpointLoader)
+from .mistral.config_loader import MistralConfigLoader
+from .mistral.weight_mapper import (MistralLarge3WeightMapper,
+                                    MistralWeightMapper)
 
 __all__ = [
-    "HfConfigLoader", "HfWeightLoader", "HfWeightMapper",
-    "BaseCheckpointLoader", "HfCheckpointLoader", "NemotronHHfWeightMapper",
-    "Gemma3HfWeightMapper", "MixtralHfWeightMapper", "Llama4HfWeightMapper",
-    "Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper",
-    "Qwen3NextHfWeightMapper", "LlavaNextHfWeightMapper"
+    "HfConfigLoader",
+    "HfWeightLoader",
+    "HfWeightMapper",
+    "MistralConfigLoader",
+    "MistralWeightMapper",
+    "MistralCheckpointLoader",
+    "BaseCheckpointLoader",
+    "HfCheckpointLoader",
+    "NemotronHHfWeightMapper",
+    "Gemma3HfWeightMapper",
+    "MixtralHfWeightMapper",
+    "Llama4HfWeightMapper",
+    "Qwen2MoeHfWeightMapper",
+    "Qwen3MoeHfWeightMapper",
+    "Qwen2VLHfWeightMapper",
+    "Qwen3NextHfWeightMapper",
+    "LlavaNextHfWeightMapper",
+    "MistralLarge3CheckpointLoader",
+    "MistralLarge3WeightMapper",
 ]
```
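With these exports, the new Mistral loaders are importable from the package root alongside the existing HF ones; for example:

```python
# Minimal usage sketch, assuming the package-level exports added above.
from tensorrt_llm._torch.models.checkpoints import (
    MistralCheckpointLoader,
    MistralConfigLoader,
    MistralLarge3CheckpointLoader,
    MistralWeightMapper,
)

# The loaders registered under "mistral" / "mistral_large_3" are now part of
# the public checkpoints API.
print(MistralCheckpointLoader, MistralLarge3CheckpointLoader)
```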

tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -19,6 +19,7 @@
 from tensorrt_llm.mapping import Mapping
 
 
+@register_checkpoint_weight_loader("mistral")
 @register_checkpoint_weight_loader("HF")
 class HfWeightLoader(BaseWeightLoader):
     """
```

tensorrt_llm/_torch/models/checkpoints/mistral/__init__.py

Whitespace-only changes.
Lines changed: 75 additions & 0 deletions (new file)

```python
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper
from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader
from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import MistralConfigLoader
from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_loader


@register_checkpoint_loader("mistral")
class MistralCheckpointLoader(HfCheckpointLoader):

    def __init__(
        self,
        *,
        weight_loader: BaseWeightLoader | None = None,
        weight_mapper: BaseWeightMapper | None = None,
        config_loader: BaseConfigLoader | None = None,
    ):
        super().__init__(
            weight_loader=weight_loader, weight_mapper=weight_mapper, config_loader=config_loader
        )
        self._checkpoint_format = "mistral"
        self.mm_module_mapping = {
            "vision_encoder": "vision_tower",
            "pre_mm_projector_norm": "multi_modal_projector.norm",
            "vision_language_adapter": "multi_modal_projector",
            "patch_merger": "multi_modal_projector.patch_merger",
        }

    def preprocess_weights(self, weights: dict) -> dict:
        """
        Aggregate weights by module
        """
        hf_weights = {}

        for key, value in weights.items():
            modules = key.split(".")

            if modules[0] not in self.mm_module_mapping.keys():
                hf_weights["language_model." + key] = value
            else:
                modules[0] = self.mm_module_mapping[modules[0]]
                hf_weights[".".join(modules)] = value

        return hf_weights

    def inverse_nvfp4_global_scales(self, weights):
        for key in weights.keys():
            if "global_scale" in key:
                weights[key] = 1.0 / weights[key]

    def load_weights(self, checkpoint_dir: str, **kwargs):
        weights = super().weight_loader.load_weights(checkpoint_dir, **kwargs)
        weights = self.preprocess_weights(weights)
        # The definition of global_scale is different in Mistral, need to inverse the scale
        self.inverse_nvfp4_global_scales(weights)
        return weights

    def get_default_config_loader(self) -> MistralConfigLoader:
        return MistralConfigLoader()


@register_checkpoint_loader("mistral_large_3")
class MistralLarge3CheckpointLoader(MistralCheckpointLoader):

    def __init__(
        self,
        *,
        weight_loader: BaseWeightLoader | None = None,
        weight_mapper: BaseWeightMapper | None = None,
        config_loader: BaseConfigLoader | None = None,
    ):
        super().__init__(
            weight_loader=weight_loader, weight_mapper=weight_mapper, config_loader=config_loader
        )
        self._checkpoint_format = "mistral_large_3"
```
