@@ -278,13 +278,14 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_max, 15)
         self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))

-    def test_offload(self):
+    def test_device_map(self):
         """
-        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
-        that the device map is correctly set (in the `hf_device_map` attribute of the model).
+        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
+        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
+        correctly set (in the `hf_device_map` attribute of the model).
         """

-        device_map_offload = {
+        custom_device_map_dict = {
             "time_text_embed": torch_device,
             "context_embedder": torch_device,
             "x_embedder": torch_device,
@@ -293,27 +294,50 @@ def test_offload(self):
             "norm_out": torch_device,
             "proj_out": "cpu",
         }
+        device_maps = ["auto", custom_device_map_dict]

         inputs = self.get_dummy_tensor_inputs(torch_device)
-
-        with tempfile.TemporaryDirectory() as offload_folder:
-            quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-            quantized_model = FluxTransformer2DModel.from_pretrained(
-                "hf-internal-testing/tiny-flux-pipe",
-                subfolder="transformer",
-                quantization_config=quantization_config,
-                device_map=device_map_offload,
-                torch_dtype=torch.bfloat16,
-                offload_folder=offload_folder,
-            )
-
-            self.assertTrue(quantized_model.hf_device_map == device_map_offload)
-
-            output = quantized_model(**inputs)[0]
-            output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-
-            expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
-            self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+        expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
+
+        for device_map in device_maps:
+            device_map_to_compare = {"": 0} if device_map == "auto" else device_map
+
+            # Test non-sharded model
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-pipe",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+
+                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+            # Test sharded model
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-sharded",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+
+                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
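
For reference, a minimal standalone sketch (not part of the test file) of the loading pattern the new test covers. It reuses the same tiny test checkpoint and the `TorchAoConfig`/`from_pretrained` arguments seen in the diff, and assumes a single CUDA GPU with torchao and accelerate installed; on one GPU, `device_map="auto"` resolves to `{"": 0}` in `hf_device_map`, which is what the test asserts.

```python
# Sketch only: load the tiny Flux transformer with int4 weight-only quantization
# and let accelerate place the modules. Assumes a single CUDA GPU.
import tempfile

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int4_weight_only", group_size=64)

with tempfile.TemporaryDirectory() as offload_folder:
    transformer = FluxTransformer2DModel.from_pretrained(
        "hf-internal-testing/tiny-flux-pipe",
        subfolder="transformer",
        quantization_config=quantization_config,
        device_map="auto",  # or a dict pinning modules to devices, e.g. {"proj_out": "cpu", ...}
        torch_dtype=torch.bfloat16,
        offload_folder=offload_folder,  # only used when modules are offloaded to disk
    )
    # accelerate records the resolved placement on the model; on a single GPU,
    # "auto" collapses to {"": 0}, matching the assertion in test_device_map above.
    print(transformer.hf_device_map)
```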