2 files changed: +5 −2

@@ -450,7 +450,7 @@ def module_is_offloaded(module):
             # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
             if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
                 module.to(device=device)
-            elif not is_loaded_in_8bit_bnb:
+            elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
                 module.to(device, dtype)

             module_has_int_weights = any(
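To spell out the intent of the changed branch, here is a minimal sketch of the resulting guard (same names as in the hunk above; the surrounding loop over pipeline modules is assumed), with the reasoning as comments:

# Sketch only -- mirrors the hunk above, not the full method.
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
    # Recent transformers allow moving 4-bit bnb modules across devices,
    # but only the device is forwarded; the 4-bit weights must not be cast.
    module.to(device=device)
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
    # Only non-quantized modules get a combined device + dtype move. Before
    # this fix, a 4-bit module that skipped the first branch (older
    # transformers, or no device given) fell through here and was cast.
    module.to(device, dtype)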
@@ -260,7 +260,7 @@ def test_device_assignment(self):
     def test_device_and_dtype_assignment(self):
         r"""
         Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error.
-        Checks also if other models are casted correctly.
+        Checks also if other models are casted correctly. Device placement, however, is supported.
         """
         with self.assertRaises(ValueError):
             # Tries with a `dtype`
@@ -278,6 +278,9 @@ def test_device_and_dtype_assignment(self):
             # Tries with a cast
             self.model_4bit.half()

+        # This should work
+        self.model_4bit.to("cuda")
+
         # Test if we did not break anything
         self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
         input_dict_for_transformer = self.get_dummy_inputs()
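The behavior the new assertion pins down, sketched at the model level (the checkpoint id and quantization config below are illustrative assumptions, not part of this diff): a dtype cast on a 4-bit bnb model should still raise ValueError, while a plain device move should succeed.

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Hypothetical 4-bit model, analogous to self.model_4bit in the test above.
model_4bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

try:
    model_4bit.to(torch.float16)  # casting the dtype is still rejected
except ValueError as err:
    print(f"expected: {err}")

model_4bit.to("cuda")  # device placement alone is supported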