
Commit 1873bb7

update device map tests
1 parent: bc47057

File tree

1 file changed: +55 −56 lines

tests/quantization/torchao/test_torchao.py

Lines changed: 55 additions & 56 deletions
@@ -281,7 +281,6 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_min, 0)
         self.assertEqual(weight.quant_max, 15)
 
-    @unittest.skip("Device map is not yet supported for TorchAO quantization.")
     def test_device_map(self):
         # Note: We were not checking if the weight tensor's were AffineQuantizedTensor's before. If we did
         # it would have errored out. Now, we do. So, device_map basically never worked with or without
@@ -291,65 +290,65 @@ def test_device_map(self):
         The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
         correctly set (in the `hf_device_map` attribute of the model).
         """
-        pass
-        # custom_device_map_dict = {
-        #     "time_text_embed": torch_device,
-        #     "context_embedder": torch_device,
-        #     "x_embedder": torch_device,
-        #     "transformer_blocks.0": "cpu",
-        #     "single_transformer_blocks.0": "disk",
-        #     "norm_out": torch_device,
-        #     "proj_out": "cpu",
-        # }
-        # device_maps = ["auto", custom_device_map_dict]
+        custom_device_map_dict = {
+            "time_text_embed": torch_device,
+            "context_embedder": torch_device,
+            "x_embedder": torch_device,
+            "transformer_blocks.0": "cpu",
+            "single_transformer_blocks.0": "disk",
+            "norm_out": torch_device,
+            "proj_out": "cpu",
+        }
+        device_maps = ["auto", custom_device_map_dict]
 
         # inputs = self.get_dummy_tensor_inputs(torch_device)
         # expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
 
-        # for device_map in device_maps:
-        #     device_map_to_compare = {"": 0} if device_map == "auto" else device_map
-
-        #     # Test non-sharded model - should work
-        #     with tempfile.TemporaryDirectory() as offload_folder:
-        #         quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-        #         quantized_model = FluxTransformer2DModel.from_pretrained(
-        #             "hf-internal-testing/tiny-flux-pipe",
-        #             subfolder="transformer",
-        #             quantization_config=quantization_config,
-        #             device_map=device_map,
-        #             torch_dtype=torch.bfloat16,
-        #             offload_folder=offload_folder,
-        #         )
-
-        #         weight = quantized_model.transformer_blocks[0].ff.net[2].weight
-        #         self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-        #         self.assertTrue(isinstance(weight, AffineQuantizedTensor))
-
-        #         output = quantized_model(**inputs)[0]
-        #         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-        #         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
-
-        #     # Test sharded model - should not work
-        #     with self.assertRaises(ValueError):
-        #         with tempfile.TemporaryDirectory() as offload_folder:
-        #             quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-        #             quantized_model = FluxTransformer2DModel.from_pretrained(
-        #                 "hf-internal-testing/tiny-flux-sharded",
-        #                 subfolder="transformer",
-        #                 quantization_config=quantization_config,
-        #                 device_map=device_map,
-        #                 torch_dtype=torch.bfloat16,
-        #                 offload_folder=offload_folder,
-        #             )
-
-        #             weight = quantized_model.transformer_blocks[0].ff.net[2].weight
-        #             self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
-        #             self.assertTrue(isinstance(weight, AffineQuantizedTensor))
-
-        #             output = quantized_model(**inputs)[0]
-        #             output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-
-        #             self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+        for device_map in device_maps:
+            # device_map_to_compare = {"": 0} if device_map == "auto" else device_map
+
+            # Test non-sharded model - should work
+            with self.assertRaises(NotImplementedError):
+                with tempfile.TemporaryDirectory() as offload_folder:
+                    quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                    _ = FluxTransformer2DModel.from_pretrained(
+                        "hf-internal-testing/tiny-flux-pipe",
+                        subfolder="transformer",
+                        quantization_config=quantization_config,
+                        device_map=device_map,
+                        torch_dtype=torch.bfloat16,
+                        offload_folder=offload_folder,
+                    )
+
+            # weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+            # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+            # self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+
+            # output = quantized_model(**inputs)[0]
+            # output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+            # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+            # Test sharded model - should not work
+            with self.assertRaises(NotImplementedError):
+                with tempfile.TemporaryDirectory() as offload_folder:
+                    quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                    _ = FluxTransformer2DModel.from_pretrained(
+                        "hf-internal-testing/tiny-flux-sharded",
+                        subfolder="transformer",
+                        quantization_config=quantization_config,
+                        device_map=device_map,
+                        torch_dtype=torch.bfloat16,
+                        offload_folder=offload_folder,
+                    )
+
+            # weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+            # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+            # self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+
+            # output = quantized_model(**inputs)[0]
+            # output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+            # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
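
In short, the commit replaces the skipped/commented-out device-map test with an explicit expectation: loading a TorchAO-quantized model with any `device_map` (either `"auto"` or the custom cpu/disk-offload dict) should currently raise `NotImplementedError`. Below is a minimal, self-contained sketch of that same expectation, assuming `diffusers` exports `FluxTransformer2DModel` and `TorchAoConfig` as used in this test file; the `torch_device` stand-in and the test class name are hypothetical, and this sketch is illustrative rather than part of the commit.

# Hypothetical standalone sketch; mirrors the expectation in the updated hunk above.
import tempfile
import unittest

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Stand-in for the test suite's torch_device helper (assumption).
torch_device = "cuda" if torch.cuda.is_available() else "cpu"


class TorchAoDeviceMapSketch(unittest.TestCase):  # hypothetical class name
    def test_device_map_not_supported(self):
        # Custom map mixing accelerator, CPU, and disk placements, as in the diff.
        custom_device_map_dict = {
            "time_text_embed": torch_device,
            "context_embedder": torch_device,
            "x_embedder": torch_device,
            "transformer_blocks.0": "cpu",
            "single_transformer_blocks.0": "disk",
            "norm_out": torch_device,
            "proj_out": "cpu",
        }
        for device_map in ["auto", custom_device_map_dict]:
            # Per this commit, combining device_map with TorchAO quantization should fail.
            with self.assertRaises(NotImplementedError):
                with tempfile.TemporaryDirectory() as offload_folder:
                    FluxTransformer2DModel.from_pretrained(
                        "hf-internal-testing/tiny-flux-pipe",
                        subfolder="transformer",
                        quantization_config=TorchAoConfig("int4_weight_only", group_size=64),
                        device_map=device_map,
                        torch_dtype=torch.bfloat16,
                        offload_folder=offload_folder,
                    )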
