@@ -278,13 +278,14 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_max, 15)
         self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))

-    def test_offload(self):
+    def test_device_map(self):
         """
-        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
-        that the device map is correctly set (in the `hf_device_map` attribute of the model).
+        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
+        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
+        correctly set (in the `hf_device_map` attribute of the model).
         """

-        device_map_offload = {
+        custom_device_map_dict = {
             "time_text_embed": torch_device,
             "context_embedder": torch_device,
             "x_embedder": torch_device,
@@ -293,27 +294,50 @@ def test_offload(self):
293294 "norm_out" : torch_device ,
294295 "proj_out" : "cpu" ,
295296 }
297+ device_maps = ["auto" , custom_device_map_dict ]
296298
297299 inputs = self .get_dummy_tensor_inputs (torch_device )
298-
299- with tempfile .TemporaryDirectory () as offload_folder :
300- quantization_config = TorchAoConfig ("int4_weight_only" , group_size = 64 )
301- quantized_model = FluxTransformer2DModel .from_pretrained (
302- "hf-internal-testing/tiny-flux-pipe" ,
303- subfolder = "transformer" ,
304- quantization_config = quantization_config ,
305- device_map = device_map_offload ,
306- torch_dtype = torch .bfloat16 ,
307- offload_folder = offload_folder ,
308- )
309-
310- self .assertTrue (quantized_model .hf_device_map == device_map_offload )
311-
312- output = quantized_model (** inputs )[0 ]
313- output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
314-
315- expected_slice = np .array ([0.3457 , - 0.0366 , 0.0105 , - 0.2275 , - 0.4941 , 0.4395 , - 0.166 , - 0.6641 , 0.4375 ])
316- self .assertTrue (np .allclose (output_slice , expected_slice , atol = 1e-3 , rtol = 1e-3 ))
300+ expected_slice = np .array ([0.3457 , - 0.0366 , 0.0105 , - 0.2275 , - 0.4941 , 0.4395 , - 0.166 , - 0.6641 , 0.4375 ])
301+
302+ for device_map in device_maps :
303+ device_map_to_compare = {"" : 0 } if device_map == "auto" else device_map
304+
305+ # Test non-sharded model
306+ with tempfile .TemporaryDirectory () as offload_folder :
307+ quantization_config = TorchAoConfig ("int4_weight_only" , group_size = 64 )
308+ quantized_model = FluxTransformer2DModel .from_pretrained (
309+ "hf-internal-testing/tiny-flux-pipe" ,
310+ subfolder = "transformer" ,
311+ quantization_config = quantization_config ,
312+ device_map = device_map ,
313+ torch_dtype = torch .bfloat16 ,
314+ offload_folder = offload_folder ,
315+ )
316+
317+ self .assertTrue (quantized_model .hf_device_map == device_map_to_compare )
318+
319+ output = quantized_model (** inputs )[0 ]
320+ output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
321+ self .assertTrue (np .allclose (output_slice , expected_slice , atol = 1e-3 , rtol = 1e-3 ))
322+
323+ # Test sharded model
324+ with tempfile .TemporaryDirectory () as offload_folder :
325+ quantization_config = TorchAoConfig ("int4_weight_only" , group_size = 64 )
326+ quantized_model = FluxTransformer2DModel .from_pretrained (
327+ "hf-internal-testing/tiny-flux-sharded" ,
328+ subfolder = "transformer" ,
329+ quantization_config = quantization_config ,
330+ device_map = device_map ,
331+ torch_dtype = torch .bfloat16 ,
332+ offload_folder = offload_folder ,
333+ )
334+
335+ self .assertTrue (quantized_model .hf_device_map == device_map_to_compare )
336+
337+ output = quantized_model (** inputs )[0 ]
338+ output_slice = output .flatten ()[- 9 :].detach ().float ().cpu ().numpy ()
339+
340+ self .assertTrue (np .allclose (output_slice , expected_slice , atol = 1e-3 , rtol = 1e-3 ))
317341
318342 def test_modules_to_not_convert (self ):
319343 quantization_config = TorchAoConfig ("int8_weight_only" , modules_to_not_convert = ["transformer_blocks.0" ])
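
For context, the behavior the new test exercises corresponds to the following usage pattern. This is a minimal sketch, not part of the commit above; it assumes torchao is installed and a CUDA device is available, and reuses the tiny test checkpoint referenced in the diff.

# Illustrative sketch of the behavior exercised by test_device_map (not part of this commit).
# Requires torchao and a CUDA device; int4 weight-only quantization is used with bfloat16.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int4_weight_only", group_size=64)

# device_map="auto" lets accelerate place all modules; for the tiny test checkpoint the
# test expects this to resolve to {"": 0}, i.e. everything on GPU 0.
transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(transformer.hf_device_map)  # e.g. {"": 0}

# Alternatively, a custom device map is a dict mapping top-level module names to devices
# (as in custom_device_map_dict above); "cpu"/"disk" entries trigger offloading, and an
# offload_folder must be supplied when any module is mapped to "disk".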