Skip to content

Commit fc895fd

Browse files
authored
Merge pull request #12 from tdoublep/fix-model-init
Fix passing torch_dtype and device_map via model_init_kwargs
2 parents 6137606 + bde0030 commit fc895fd

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

auto_fp8/modeling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ def skip(*args, **kwargs):
7171
torch.cuda.empty_cache()
7272

7373
# Important defaults
74-
if not hasattr(model_init_kwargs, "torch_dtype"):
74+
if "torch_dtype" not in model_init_kwargs:
7575
model_init_kwargs["torch_dtype"] = "auto"
7676

77-
if not hasattr(model_init_kwargs, "device_map"):
77+
if "device_map" not in model_init_kwargs:
7878
model_init_kwargs["device_map"] = "auto"
7979

8080
merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}

0 commit comments

Comments
 (0)