Skip to content

Commit dd10acf

Browse files
committed
Fix passing torch_dtype and device_map via model_init_kwargs
Signed-off-by: Thomas Parnell <[email protected]>
1 parent 147fa4d commit dd10acf

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

auto_fp8/modeling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ def skip(*args, **kwargs):
7171
torch.cuda.empty_cache()
7272

7373
# Important defaults
74-
if not hasattr(model_init_kwargs, "torch_dtype"):
74+
if not "torch_dtype" in model_init_kwargs:
7575
model_init_kwargs["torch_dtype"] = "auto"
7676

77-
if not hasattr(model_init_kwargs, "device_map"):
77+
if not "device_map" in model_init_kwargs:
7878
model_init_kwargs["device_map"] = "auto"
7979

8080
merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}

0 commit comments

Comments
 (0)