@@ -279,14 +279,13 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_min, 0)
         self.assertEqual(weight.quant_max, 15)
 
-    def test_device_map(self):
+    def test_offload(self):
         """
-        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
-        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
-        correctly set (in the `hf_device_map` attribute of the model).
+        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
+        that the device map is correctly set (in the `hf_device_map` attribute of the model).
         """
 
-        custom_device_map_dict = {
+        device_map_offload = {
             "time_text_embed": torch_device,
             "context_embedder": torch_device,
             "x_embedder": torch_device,
@@ -295,50 +294,27 @@ def test_device_map(self):
             "norm_out": torch_device,
             "proj_out": "cpu",
         }
-        device_maps = ["auto", custom_device_map_dict]
 
         inputs = self.get_dummy_tensor_inputs(torch_device)
-        expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
-
-        for device_map in device_maps:
-            device_map_to_compare = {"": 0} if device_map == "auto" else device_map
-
-            # Test non-sharded model
-            with tempfile.TemporaryDirectory() as offload_folder:
-                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-                quantized_model = FluxTransformer2DModel.from_pretrained(
-                    "hf-internal-testing/tiny-flux-pipe",
-                    subfolder="transformer",
-                    quantization_config=quantization_config,
-                    device_map=device_map,
-                    torch_dtype=torch.bfloat16,
-                    offload_folder=offload_folder,
-                )
-
-                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-
-                output = quantized_model(**inputs)[0]
-                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
-
-            # Test sharded model
-            with tempfile.TemporaryDirectory() as offload_folder:
-                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-                quantized_model = FluxTransformer2DModel.from_pretrained(
-                    "hf-internal-testing/tiny-flux-sharded",
-                    subfolder="transformer",
-                    quantization_config=quantization_config,
-                    device_map=device_map,
-                    torch_dtype=torch.bfloat16,
-                    offload_folder=offload_folder,
-                )
-
-                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-
-                output = quantized_model(**inputs)[0]
-                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-
-                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+        with tempfile.TemporaryDirectory() as offload_folder:
+            quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+            quantized_model = FluxTransformer2DModel.from_pretrained(
+                "hf-internal-testing/tiny-flux-pipe",
+                subfolder="transformer",
+                quantization_config=quantization_config,
+                device_map=device_map_offload,
+                torch_dtype=torch.bfloat16,
+                offload_folder=offload_folder,
+            )
+
+            self.assertTrue(quantized_model.hf_device_map == device_map_offload)
+
+            output = quantized_model(**inputs)[0]
+            output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+            expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
+            self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
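
For readers who want to try the offload path outside the test suite, here is a minimal standalone sketch of the pattern the new test_offload exercises. The repo id, group size, and the device-map entries visible in the diff come straight from the test; the CUDA target and the two "*_blocks.0" entries are assumptions standing in for the entries the collapsed part of the diff does not show. Treat it as an illustration of the cpu/disk offload mechanics, not as the exact test code.

# Sketch of the cpu/disk offload pattern exercised by test_offload above.
# Assumes a CUDA device plus the diffusers, torchao, and accelerate packages.
import tempfile

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

device_map_offload = {
    "time_text_embed": "cuda",
    "context_embedder": "cuda",
    "x_embedder": "cuda",
    "transformer_blocks.0": "cpu",          # assumed entry: kept in CPU RAM
    "single_transformer_blocks.0": "disk",  # assumed entry: written to offload_folder
    "norm_out": "cuda",
    "proj_out": "cpu",
}

with tempfile.TemporaryDirectory() as offload_folder:
    transformer = FluxTransformer2DModel.from_pretrained(
        "hf-internal-testing/tiny-flux-pipe",
        subfolder="transformer",
        quantization_config=TorchAoConfig("int4_weight_only", group_size=64),
        device_map=device_map_offload,
        torch_dtype=torch.bfloat16,
        offload_folder=offload_folder,  # needed whenever a module is mapped to "disk"
    )
    # After dispatch, the resolved placement is recorded on the model,
    # which is what the test asserts against.
    print(transformer.hf_device_map)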