Skip to content

Commit 9b54b3b

Browse files
authored
[None][chore] AutoDeploy: replace HF's deprecated keyword torch_dtype --> dtype (NVIDIA#8510)
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 8dc4aac commit 9b54b3b

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -107,15 +107,6 @@ def __init__(self, *args, **kwargs):
107107
self.model_kwargs,
108108
)
109109

110-
# special handling for torch_dtype in model_kwargs since HF does not correctly update
111-
# torch_dtype string to an actual torch.dtype object (only with default)
112-
if "torch_dtype" in self.model_kwargs:
113-
dtype = self.model_kwargs["torch_dtype"]
114-
if isinstance(dtype, str):
115-
dtype = getattr(torch, self.model_kwargs["torch_dtype"])
116-
assert isinstance(dtype, torch.dtype), f"Invalid dtype: {dtype}"
117-
self.model_kwargs["torch_dtype"] = dtype
118-
119110
# set sharding config source to huggingface
120111
self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE
121112

@@ -159,6 +150,16 @@ def _recursive_update_config(
159150
setattr(config, key, updated_value)
160151
if child_unused:
161152
nested_unused_kwargs[key] = child_unused
153+
elif (
154+
key in ["torch_dtype", "dtype"]
155+
and isinstance(value_new, str)
156+
and value_new != "auto"
157+
):
158+
# check special handling of torch_dtype (DEPRECATED!) and dtype key to ensure we
159+
# use the correct torch.dtype object instead of a string.
160+
dtype = getattr(torch, value_new)
161+
assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
162+
setattr(config, key, dtype)
162163
else:
163164
# Direct update for simple values
164165
setattr(config, key, value_new)
@@ -278,7 +279,7 @@ def build_and_load_model(self, device: DeviceLikeType) -> nn.Module:
278279
"trust_remote_code": True,
279280
"tp_plan": "auto",
280281
**unused_kwargs,
281-
"torch_dtype": "auto", # takes precedence over unused_kwargs!
282+
"dtype": "auto", # takes precedence over unused_kwargs!
282283
},
283284
)
284285
model.eval()

tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
465465
"ibm-ai-platform/Bamba-9B-v2": {
466466
"llm_models_subdir": "Bamba-9B-v2",
467467
"model_kwargs": {
468-
"torch_dtype": "bfloat16",
468+
"dtype": "bfloat16",
469469
"hidden_size": 64,
470470
"intermediate_size": 128,
471471
"mamba_chunk_size": 64,
@@ -484,7 +484,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
484484
"nvidia/NVIDIA-Nemotron-Nano-12B-v2": {
485485
"llm_models_subdir": "NVIDIA-Nemotron-Nano-12B-v2",
486486
"model_kwargs": {
487-
"torch_dtype": "bfloat16",
487+
"dtype": "bfloat16",
488488
"hidden_size": 32,
489489
"intermediate_size": 64,
490490
"mamba_head_dim": 40,

tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_hybrid_patches.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_bamba_patches(model_dir: str, run_verify_generation: bool):
4848
**common_kwargs,
4949
"model_kwargs": {
5050
"use_cache": use_cache,
51-
"torch_dtype": "bfloat16",
51+
"dtype": "bfloat16",
5252
},
5353
}
5454
llm_args = AutoDeployConfig(**llm_args)

0 commit comments

Comments
 (0)