@@ -279,13 +279,14 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_max, 15)
         self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
 
-    def test_offload(self):
+    def test_device_map(self):
         """
-        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
-        that the device map is correctly set (in the `hf_device_map` attribute of the model).
+        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
+        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
+        correctly set (in the `hf_device_map` attribute of the model).
         """
 
-        device_map_offload = {
+        custom_device_map_dict = {
             "time_text_embed": torch_device,
             "context_embedder": torch_device,
             "x_embedder": torch_device,
@@ -294,27 +295,51 @@ def test_offload(self):
294295 "norm_out" : torch_device ,
295296 "proj_out" : "cpu" ,
296297 }
298+ device_maps = ["auto" , custom_device_map_dict ]
297299
298300 inputs = self .get_dummy_tensor_inputs (torch_device )
299-
300- with tempfile .TemporaryDirectory () as offload_folder :
301- quantization_config = TorchAoConfig ("int4_weight_only" , group_size = 64 )
302- quantized_model = FluxTransformer2DModel .from_pretrained (
303- "hf-internal-testing/tiny-flux-pipe" ,
304- subfolder = "transformer" ,
305- quantization_config = quantization_config ,
306- device_map = device_map_offload ,
307- torch_dtype = torch .bfloat16 ,
308- offload_folder = offload_folder ,
309- )
310-
311- self .assertTrue (quantized_model .hf_device_map == device_map_offload )
312-
313- output = quantized_model (** inputs )[0 ]
314- output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
315-
316- expected_slice = np .array ([0.3457 , - 0.0366 , 0.0105 , - 0.2275 , - 0.4941 , 0.4395 , - 0.166 , - 0.6641 , 0.4375 ])
317- self .assertTrue (np .allclose (output_slice , expected_slice , atol = 1e-3 , rtol = 1e-3 ))
301+ expected_slice = np .array ([0.3457 , - 0.0366 , 0.0105 , - 0.2275 , - 0.4941 , 0.4395 , - 0.166 , - 0.6641 , 0.4375 ])
302+
303+ for device_map in device_maps :
304+ device_map_to_compare = {"" : 0 } if device_map == "auto" else device_map
305+
306+ # Test non-sharded model
307+ with tempfile .TemporaryDirectory () as offload_folder :
308+ quantization_config = TorchAoConfig ("int4_weight_only" , group_size = 64 )
309+ quantized_model = FluxTransformer2DModel .from_pretrained (
310+ "hf-internal-testing/tiny-flux-pipe" ,
311+ subfolder = "transformer" ,
312+ quantization_config = quantization_config ,
313+ device_map = device_map ,
314+ torch_dtype = torch .bfloat16 ,
315+ offload_folder = offload_folder ,
316+ )
317+
318+ self .assertTrue (quantized_model .hf_device_map == device_map_to_compare )
319+
320+ output = quantized_model (** inputs )[0 ]
321+ output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
322+ self .assertTrue (np .allclose (output_slice , expected_slice , atol = 1e-3 , rtol = 1e-3 ))
323+
324+ # Test sharded model
325+ with tempfile .TemporaryDirectory () as offload_folder :
326+ quantization_config = TorchAoConfig ("int4_weight_only" , group_size = 64 )
327+ quantized_model = FluxTransformer2DModel .from_pretrained (
328+ "hf-internal-testing/tiny-flux-sharded" ,
329+ subfolder = "transformer" ,
330+ quantization_config = quantization_config ,
331+ device_map = device_map ,
332+ torch_dtype = torch .bfloat16 ,
333+ offload_folder = offload_folder ,
334+ )
335+
336+ self .assertTrue (quantized_model .hf_device_map == device_map_to_compare )
337+
338+ output = quantized_model (** inputs )[0 ]
339+ output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
340+
341+ expected_slice = np .array ([0.3457 , - 0.0366 , 0.0105 , - 0.2275 , - 0.4941 , 0.4395 , - 0.166 , - 0.6641 , 0.4375 ])
342+ self .assertTrue (np .allclose (output_slice , expected_slice , atol = 1e-3 , rtol = 1e-3 ))
318343
319344 def test_modules_to_not_convert (self ):
320345 quantization_config = TorchAoConfig ("int8_weight_only" , modules_to_not_convert = ["transformer_blocks.0" ])
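
Outside the test harness, the loading pattern these tests exercise looks roughly like the sketch below, assuming a diffusers build with TorchAo support and an available accelerator. The "auto" map can be swapped for an explicit module-to-device dict such as `custom_device_map_dict` above; this is an illustrative usage sketch, not code from the commit.

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# int4 weight-only quantization, letting accelerate decide device placement;
# passing an explicit dict (e.g. mapping "proj_out" to "cpu") offloads selected modules instead.
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(transformer.hf_device_map)  # resolved placement, e.g. {"": 0} on a single GPU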