
Commit 364472a

added wrapper for execution device (#417)
* added wrapper for execution device
* break out cast_to_device

Signed-off-by: shanjiaz <[email protected]>
1 parent 2154c62 commit 364472a


src/compressed_tensors/utils/offload.py

Lines changed: 15 additions & 1 deletion
@@ -86,6 +86,7 @@
     "offloaded_dispatch",
     "disable_offloading",
     "remove_dispatch",
+    "cast_to_device",
 ]


@@ -169,6 +170,19 @@ def update_parameter_data(
 """ Candidates for Upstreaming """


+def cast_to_device(device_spec: Union[int, torch.device]) -> torch.device:
+    """
+    Convert an integer device index or torch.device into a torch.device object.
+
+    :param device_spec: Device index (int) or torch.device object.
+        Negative integers map to CPU.
+    :return: torch.device corresponding to the given device specification.
+    """
+    if isinstance(device_spec, int):
+        return torch.device(f"cuda:{device_spec}" if device_spec >= 0 else "cpu")
+    return device_spec
+
+
 def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
     Get the device which inputs should be moved to before module execution.
@@ -179,7 +193,7 @@ def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
     for submodule in module.modules():
         if has_offloaded_params(submodule):
-            return submodule._hf_hook.execution_device
+            return cast_to_device(submodule._hf_hook.execution_device)

         param = next(submodule.parameters(recurse=False), None)
         if param is not None:
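For context, a minimal usage sketch (not part of the commit) of the new helper, assuming cast_to_device is imported from compressed_tensors.utils.offload as exported above; the device values shown are illustrative:

import torch
from compressed_tensors.utils.offload import cast_to_device

# Integer device indices (as an accelerate hook's execution_device may be stored)
# are normalized into torch.device objects; torch.device inputs pass through unchanged.
print(cast_to_device(0))                        # cuda:0  (non-negative index -> that CUDA device)
print(cast_to_device(-1))                       # cpu     (negative index -> CPU)
print(cast_to_device(torch.device("cuda:1")))   # cuda:1  (already a torch.device, returned as-is)

With this change, get_execution_device returns a torch.device even when the offload hook stores the execution device as a plain integer index.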
