File tree Expand file tree Collapse file tree 1 file changed +15
-1
lines changed
src/compressed_tensors/utils Expand file tree Collapse file tree 1 file changed +15
-1
lines changed Original file line number Diff line number Diff line change 86
86
"offloaded_dispatch" ,
87
87
"disable_offloading" ,
88
88
"remove_dispatch" ,
89
+ "cast_to_device" ,
89
90
]
90
91
91
92
@@ -169,6 +170,19 @@ def update_parameter_data(
169
170
""" Candidates for Upstreaming """
170
171
171
172
173
+ def cast_to_device (device_spec : Union [int , torch .device ]) -> torch .device :
174
+ """
175
+ Convert an integer device index or torch.device into a torch.device object.
176
+
177
+ :param device_spec: Device index (int) or torch.device object.
178
+ Negative integers map to CPU.
179
+ :return: torch.device corresponding to the given device specification.
180
+ """
181
+ if isinstance (device_spec , int ):
182
+ return torch .device (f"cuda:{ device_spec } " if device_spec >= 0 else "cpu" )
183
+ return device_spec
184
+
185
+
172
186
def get_execution_device (module : torch .nn .Module ) -> torch .device :
173
187
"""
174
188
Get the device which inputs should be moved to before module execution.
@@ -179,7 +193,7 @@ def get_execution_device(module: torch.nn.Module) -> torch.device:
179
193
"""
180
194
for submodule in module .modules ():
181
195
if has_offloaded_params (submodule ):
182
- return submodule ._hf_hook .execution_device
196
+ return cast_to_device ( submodule ._hf_hook .execution_device )
183
197
184
198
param = next (submodule .parameters (recurse = False ), None )
185
199
if param is not None :
You can’t perform that action at this time.
0 commit comments