Skip to content

Commit 01f0836

Browse files
committed
update device passing and tests
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent f0944df commit 01f0836

File tree

16 files changed

+139
-61
lines changed

16 files changed

+139
-61
lines changed

tensorrt_llm/_torch/auto_deploy/config/transformers.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ transforms:
2020
transformers_replace_cached_attn:
2121
stage: cache_init
2222
attn_backend: flashinfer
23-
expected_layout: bsnd
2423
initialize_cache:
2524
stage: cache_init
2625
resize_kv_cache:

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
7575
"If True, only the model architecture is loaded.",
7676
)
7777

78-
# checkpoint_device: Optional[str] = Field(
79-
# default=None,
80-
# description="Device on which to load the model checkpoint. "
81-
# "Defaults to the same device as the rest of the pipeline.",
82-
# )
83-
8478
tokenizer: Optional[PathLike] = Field(
8579
description="The tokenizer",
8680
default=None,
@@ -169,6 +163,12 @@ def update_attn_page_size(self):
169163
"torch",
170164
]:
171165
self.attn_page_size = self.max_seq_len
166+
# NOTE: (hg) For transformers mode. This is ugly.
167+
if self.transforms.get("transformers_replace_cached_attn", {}).get("attn_backend") in [
168+
"triton",
169+
"torch",
170+
]:
171+
self.attn_page_size = self.max_seq_len
172172
return self
173173

174174
@field_validator("model_factory", mode="after")

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,7 @@ def build_from_config(cls, ad_config: AutoDeployConfig):
114114
# ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.
115115

116116
# construct inference optimizer
117-
build_and_optimize = InferenceOptimizer(
118-
factory=factory, config=ad_config.transforms, local_device=device
119-
)
117+
build_and_optimize = InferenceOptimizer(factory=factory, config=ad_config.transforms)
120118

121119
# construct engine
122120
return cls(build_and_optimize, seq_info, device, max_beam_width)

tensorrt_llm/_torch/auto_deploy/transform/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class SharedConfig(BaseModel):
5555
sharding_config: ShardingConfig = Field(default_factory=ShardingConfig)
5656
local_rank: int = Field(default=0)
5757
world_size: int = Field(default=1)
58-
local_device: str = Field(description="Current rank device.")
58+
# local_device: str = Field(description="Current rank device.")
5959

6060

6161
class TransformConfig(BaseModel):

tensorrt_llm/_torch/auto_deploy/transform/library/attention.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Pattern matching for detecting repeat_kv, eager, grouped attention patterns from Huggingface models."""
22

3-
from typing import Any, Callable, Dict, List, Literal, Tuple, Type
3+
from typing import Any, Callable, Dict, List, Tuple, Type
44

55
import torch
66
import torch.nn.functional as F
@@ -496,9 +496,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass):
496496
class MatchAttentionLayoutConfig(TransformConfig):
497497
"""Configuration for the match attention layout transform."""
498498

499-
attn_backend: Literal["flashinfer", "triton", "torch"] = Field(
500-
description="Attention backend to use."
501-
)
499+
attn_backend: str = Field(description="Attention backend to use.")
502500

503501

504502
@TransformRegistry.register("match_attention_layout")

tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _apply(
8585
assert isinstance(factory, hf.AutoModelFactory), "Only HF models are supported."
8686

8787
# build and load the model
88-
model = factory.build_and_load_model(shared_config.local_device)
88+
model = factory.build_and_load_model(cm.device)
8989

9090
assert not self.config.use_strict_forward, "Only regular forward is supported."
9191

tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ class ResizeKVCacheConfig(TransformConfig):
224224
"""Configuration for the resize kv cache transform."""
225225

226226
free_mem_ratio: float = Field(
227-
description="The fraction of available memory to occupy.", default=0.8
227+
default=0.8, ge=0.0, le=1.0, description="The fraction of available memory to occupy."
228228
)
229229
args_only: bool = Field(
230230
description="Use ``*cm.args`` (default) or use ``**cm.named_args`` for the forward pass.",

tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ def _apply(
4545
) -> Tuple[GraphModule, TransformInfo]:
4646
factory.load_or_random_init(
4747
gm,
48-
device=self.config.checkpoint_device or shared_config.local_device,
48+
device=self.config.checkpoint_device or cm.device,
4949
)
50-
move_to_device(gm, shared_config.local_device)
50+
move_to_device(gm, cm.device)
5151

5252
info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)
5353

@@ -65,7 +65,9 @@ def _apply(
6565
factory: ModelFactory,
6666
shared_config: SharedConfig,
6767
) -> Tuple[GraphModule, TransformInfo]:
68-
cm.to(shared_config.local_device)
68+
# TODO (hg) This is weird but equivalent to the previous code.
69+
# We do not seem to need this transform.
70+
cm.to(cm.device)
6971

7072
info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)
7173

tensorrt_llm/_torch/auto_deploy/transform/optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,17 @@
2222

2323

2424
class InferenceOptimizer:
25-
def __init__(self, factory: ModelFactory, config: InferenceOptimizerConfig, local_device: str):
25+
def __init__(self, factory: ModelFactory, config: InferenceOptimizerConfig):
2626
self.factory = factory
2727
self.config = self._clean_config(config)
2828
if not dist.is_initialized():
2929
local_rank, world_size = 0, 1
3030
else:
3131
local_rank, world_size = dist_ad.get_rank_world_size()
3232
self.shared_config = SharedConfig(
33-
local_rank=local_rank, world_size=world_size, local_device=local_device
33+
local_rank=local_rank,
34+
world_size=world_size,
35+
# local_device=local_device
3436
)
3537

3638
def _clean_config(self, config: InferenceOptimizerConfig) -> StrictInferenceOptimizerConfig:

tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
440440
"mistralai/Mistral-Small-3.1-24B-Instruct-2503": {
441441
"llm_models_subdir": "Mistral-Small-3.1-24B-Instruct-2503",
442442
"model_factory": "Mistral3VLM",
443-
"compile_backend": "torch-simple",
443+
# "compile_backend": "torch-simple",
444444
"model_kwargs": {
445445
"text_config": {"num_hidden_layers": 2},
446446
"vision_config": {"num_hidden_layers": 2},
@@ -473,10 +473,8 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
473473

474474
# add some defaults to llm_args
475475
llm_args["skip_loading_weights"] = True # No weight loading to speed up things
476-
llm_args["free_mem_ratio"] = 0.00 # we don't need the cache and it may cause OOM issues
477476
llm_args["attn_page_size"] = 4 # Make sure paging is activated despite small max_tokens
478477
llm_args["max_batch_size"] = 2 # Minimum batching to speed up things
479-
480478
# update with custom llm_args kwargs
481479
llm_args.update(llm_args_kwargs)
482480

@@ -494,10 +492,16 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, An
494492

495493

496494
def get_small_model_config_pytest_param(
497-
model_hub_id: str, pytest_param_kwargs=None, **llm_args_kwargs
495+
model_hub_id: str,
496+
attn_backend: str,
497+
compile_backend: str,
498+
pytest_param_kwargs=None,
499+
**llm_args_kwargs,
498500
):
499501
return pytest.param(
500502
get_small_model_config(model_hub_id, **llm_args_kwargs),
503+
attn_backend,
504+
compile_backend,
501505
id=model_hub_id,
502506
**(pytest_param_kwargs or {}),
503507
)

0 commit comments

Comments
 (0)