Skip to content

Commit ba1269d

Browse files
committed
update
1 parent 87bb2fe commit ba1269d

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/quantization/torchao/test_torchao.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ def tearDown(self):
131131
gc.collect()
132132
torch.cuda.empty_cache()
133133

134-
def get_dummy_components(self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"):
134+
def get_dummy_components(
135+
self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
136+
):
135137
transformer = FluxTransformer2DModel.from_pretrained(
136138
model_id,
137139
subfolder="transformer",
@@ -436,7 +438,9 @@ def test_memory_footprint(self):
436438
"""
437439
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
438440
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
439-
transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32), model_id=model_id)["transformer"]
441+
transformer_int4wo_gs32 = self.get_dummy_components(
442+
TorchAoConfig("int4wo", group_size=32), model_id=model_id
443+
)["transformer"]
440444
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
441445
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
442446

@@ -654,7 +658,7 @@ def test_quantization(self):
654658
gc.collect()
655659
torch.cuda.empty_cache()
656660
torch.cuda.synchronize()
657-
661+
658662
def test_serialization(self):
659663
quantization_config = TorchAoConfig("int8wo")
660664
components = self.get_dummy_components(quantization_config)
@@ -673,6 +677,6 @@ def test_serialization(self):
673677

674678
weight = loaded_pipe.transformer.x_embedder.weight
675679
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
676-
680+
677681
loaded_output = loaded_pipe(**inputs)[0].flatten()
678682
self.assertTrue(np.allclose(output, loaded_output, atol=1e-3, rtol=1e-3))

0 commit comments

Comments
 (0)