@@ -281,7 +281,6 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_min, 0)
         self.assertEqual(weight.quant_max, 15)
 
-    @unittest.skip("Device map is not yet supported for TorchAO quantization.")
     def test_device_map(self):
         # Note: We were not checking whether the weight tensors were AffineQuantizedTensors before. If we did
         # it would have errored out. Now, we do. So, device_map basically never worked with or without
@@ -291,65 +290,65 @@ def test_device_map(self):
         The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
         correctly set (in the `hf_device_map` attribute of the model).
         """
-        pass
-        # custom_device_map_dict = {
-        #     "time_text_embed": torch_device,
-        #     "context_embedder": torch_device,
-        #     "x_embedder": torch_device,
-        #     "transformer_blocks.0": "cpu",
-        #     "single_transformer_blocks.0": "disk",
-        #     "norm_out": torch_device,
-        #     "proj_out": "cpu",
-        # }
-        # device_maps = ["auto", custom_device_map_dict]
+        custom_device_map_dict = {
+            "time_text_embed": torch_device,
+            "context_embedder": torch_device,
+            "x_embedder": torch_device,
+            "transformer_blocks.0": "cpu",
+            "single_transformer_blocks.0": "disk",
+            "norm_out": torch_device,
+            "proj_out": "cpu",
+        }
+        device_maps = ["auto", custom_device_map_dict]
 
         # inputs = self.get_dummy_tensor_inputs(torch_device)
         # expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
 
-        # for device_map in device_maps:
-        #     device_map_to_compare = {"": 0} if device_map == "auto" else device_map
-
-        #     # Test non-sharded model - should work
-        #     with tempfile.TemporaryDirectory() as offload_folder:
-        #         quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-        #         quantized_model = FluxTransformer2DModel.from_pretrained(
-        #             "hf-internal-testing/tiny-flux-pipe",
-        #             subfolder="transformer",
-        #             quantization_config=quantization_config,
-        #             device_map=device_map,
-        #             torch_dtype=torch.bfloat16,
-        #             offload_folder=offload_folder,
-        #         )
-
-        #         weight = quantized_model.transformer_blocks[0].ff.net[2].weight
-        #         self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-        #         self.assertTrue(isinstance(weight, AffineQuantizedTensor))
-
-        #         output = quantized_model(**inputs)[0]
-        #         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-        #         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
-
-        #     # Test sharded model - should not work
-        #     with self.assertRaises(ValueError):
-        #         with tempfile.TemporaryDirectory() as offload_folder:
-        #             quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-        #             quantized_model = FluxTransformer2DModel.from_pretrained(
-        #                 "hf-internal-testing/tiny-flux-sharded",
-        #                 subfolder="transformer",
-        #                 quantization_config=quantization_config,
-        #                 device_map=device_map,
-        #                 torch_dtype=torch.bfloat16,
-        #                 offload_folder=offload_folder,
-        #             )
-
-        #             weight = quantized_model.transformer_blocks[0].ff.net[2].weight
-        #             self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-        #             self.assertTrue(isinstance(weight, AffineQuantizedTensor))
-
-        #             output = quantized_model(**inputs)[0]
-        #             output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-
-        #             self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+        for device_map in device_maps:
+            # device_map_to_compare = {"": 0} if device_map == "auto" else device_map
+
+            # Test non-sharded model - device_map is not yet supported with TorchAO quantization, so this should raise
+            with self.assertRaises(NotImplementedError):
+                with tempfile.TemporaryDirectory() as offload_folder:
+                    quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                    _ = FluxTransformer2DModel.from_pretrained(
+                        "hf-internal-testing/tiny-flux-pipe",
+                        subfolder="transformer",
+                        quantization_config=quantization_config,
+                        device_map=device_map,
+                        torch_dtype=torch.bfloat16,
+                        offload_folder=offload_folder,
+                    )
+
+                    # weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+                    # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+                    # self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+
+                    # output = quantized_model(**inputs)[0]
+                    # output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+                    # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+            # Test sharded model - should not work either
+            with self.assertRaises(NotImplementedError):
+                with tempfile.TemporaryDirectory() as offload_folder:
+                    quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                    _ = FluxTransformer2DModel.from_pretrained(
+                        "hf-internal-testing/tiny-flux-sharded",
+                        subfolder="transformer",
+                        quantization_config=quantization_config,
+                        device_map=device_map,
+                        torch_dtype=torch.bfloat16,
+                        offload_folder=offload_folder,
+                    )
+
+                    # weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+                    # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+                    # self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+
+                    # output = quantized_model(**inputs)[0]
+                    # output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+                    # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
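A minimal standalone sketch of the behavior the updated test asserts, assuming the same tiny checkpoint and the top-level diffusers imports used above; the expected NotImplementedError comes directly from the assertRaises calls in the new test body:

    import torch
    from diffusers import FluxTransformer2DModel, TorchAoConfig

    # Passing a device_map together with a TorchAO quantization config is expected
    # to raise NotImplementedError for now (mirrors what the test above asserts).
    quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
    try:
        FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
    except NotImplementedError as err:
        print(f"device_map is not yet supported with TorchAO quantization: {err}")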