
Commit 739601c

add test for sharded model
1 parent c129428

File tree: 1 file changed (+48, -23)

tests/quantization/torchao/test_torchao.py

Lines changed: 48 additions & 23 deletions
@@ -279,13 +279,14 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_max, 15)
         self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
 
-    def test_offload(self):
+    def test_device_map(self):
         """
-        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
-        that the device map is correctly set (in the `hf_device_map` attribute of the model).
+        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
+        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
+        correctly set (in the `hf_device_map` attribute of the model).
         """
 
-        device_map_offload = {
+        custom_device_map_dict = {
             "time_text_embed": torch_device,
             "context_embedder": torch_device,
             "x_embedder": torch_device,
@@ -294,27 +295,51 @@ def test_offload(self):
             "norm_out": torch_device,
             "proj_out": "cpu",
         }
+        device_maps = ["auto", custom_device_map_dict]
 
         inputs = self.get_dummy_tensor_inputs(torch_device)
-
-        with tempfile.TemporaryDirectory() as offload_folder:
-            quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-            quantized_model = FluxTransformer2DModel.from_pretrained(
-                "hf-internal-testing/tiny-flux-pipe",
-                subfolder="transformer",
-                quantization_config=quantization_config,
-                device_map=device_map_offload,
-                torch_dtype=torch.bfloat16,
-                offload_folder=offload_folder,
-            )
-
-            self.assertTrue(quantized_model.hf_device_map == device_map_offload)
-
-            output = quantized_model(**inputs)[0]
-            output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-
-            expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
-            self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+        expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
+
+        for device_map in device_maps:
+            device_map_to_compare = {"": 0} if device_map == "auto" else device_map
+
+            # Test non-sharded model
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-pipe",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+
+                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+            # Test sharded model
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-sharded",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+
+                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+                expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
+                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
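
For context, the new test_device_map iterates over both an "auto" device map and the custom dict (which offloads proj_out to CPU), and runs the same checks against a non-sharded checkpoint (hf-internal-testing/tiny-flux-pipe) and the sharded one (hf-internal-testing/tiny-flux-sharded) added by this commit. A minimal sketch of the loading pattern the test exercises, outside the test harness, could look like the following; the top-level import path for TorchAoConfig and the availability of a single CUDA device are assumptions, not part of this commit.

# Sketch only: assumes `TorchAoConfig` is exported from `diffusers` and a CUDA GPU is present.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Same quantization settings as the test: int4 weight-only with group size 64.
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)

model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-sharded",  # sharded checkpoint covered by the new test
    subfolder="transformer",
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# For a tiny model that fits on one GPU, "auto" resolves to everything on device 0,
# which is why the test compares `hf_device_map` against {"": 0}.
print(model.hf_device_map)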
