
Commit 886ea47

unify the quant compile + offloading tests.
1 parent f33b89b commit 886ea47

File tree

5 files changed: +30 -43 lines changed


tests/quantization/bnb/test_4bit.py

Lines changed: 2 additions & 7 deletions

@@ -888,12 +888,7 @@ def quantization_config(self):
 
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile()
 
     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=True)

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 3 additions & 7 deletions

@@ -849,15 +849,11 @@ def quantization_config(self):
 
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
+        super()._test_torch_compile(torch_dtype=torch.float16)
 
     def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
-        )
+        super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
 
     @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)

tests/quantization/gguf/test_gguf.py

Lines changed: 0 additions & 9 deletions

@@ -662,15 +662,6 @@ class GGUFCompileTests(QuantCompileTests):
     def quantization_config(self):
         return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
 
-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
-
     def _init_pipeline(self, *args, **kwargs):
         transformer = FluxTransformer2DModel.from_single_file(
             self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
tests/quantization/test_torch_compile_utils.py

Lines changed: 23 additions & 15 deletions

@@ -50,30 +50,29 @@ def _init_pipeline(self, quantization_config, torch_dtype):
         )
         return pipe
 
-    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
-        # import to ensure fullgraph True
+    def _test_torch_compile(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
+        # `fullgraph=True` ensures no graph breaks
         pipe.transformer.compile(fullgraph=True)
 
-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            for _ in range(2):
+                # small resolutions to ensure speedy execution.
+                pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
-    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+    def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
         pipe.transformer.compile()
 
         for _ in range(2):
             # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
-    def _test_torch_compile_with_group_offload_leaf(
-        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
-    ):
-        torch._dynamo.config.cache_size_limit = 10000
+    def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
+        torch._dynamo.config.cache_size_limit = 1000
 
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         group_offload_kwargs = {
            "onload_device": torch.device("cuda"),
            "offload_device": torch.device("cpu"),
@@ -89,4 +88,13 @@ def _test_torch_compile_with_group_offload_leaf(
 
         for _ in range(2):
             # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+    def test_torch_compile(self):
+        self._test_torch_compile()
+
+    def test_torch_compile_with_cpu_offload(self):
+        self._test_torch_compile_with_cpu_offload()
+
+    def test_torch_compile_with_group_offload_leaf(self):
+        self._test_torch_compile_with_group_offload_leaf()
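
With the shared helpers now reading `self.quantization_config` and the `test_*` entry points living on `QuantCompileTests` itself, a backend-specific test class only has to supply its config, overriding an individual test only when the backend needs extra setup (as the bnb, GGUF, and torchao diffs in this commit show). Below is a minimal sketch of the resulting pattern, assuming the config hook is exposed as a property (as the `self.quantization_config` lookups in the helpers suggest) and using a hypothetical `FooQuantizationConfig` as a stand-in for a real backend config:

class FooCompileTests(QuantCompileTests):
    # QuantCompileTests is the base class defined in
    # tests/quantization/test_torch_compile_utils.py (see the diff above).
    # `FooQuantizationConfig` is hypothetical; real subclasses return e.g.
    # BitsAndBytesConfig, GGUFQuantizationConfig, or TorchAoConfig here.
    @property
    def quantization_config(self):
        return FooQuantizationConfig()

    # test_torch_compile, test_torch_compile_with_cpu_offload, and
    # test_torch_compile_with_group_offload_leaf are inherited from
    # QuantCompileTests; each reads `self.quantization_config` internally,
    # so no per-backend plumbing of the config is needed anymore.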

tests/quantization/torchao/test_torchao.py

Lines changed: 2 additions & 5 deletions

@@ -639,16 +639,13 @@ def quantization_config(self):
            },
        )
 
-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
     @unittest.skip(
         "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
         "when compiling."
     )
     def test_torch_compile_with_cpu_offload(self):
         # RuntimeError: _apply(): Couldn't swap Linear.weight
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super()._test_torch_compile_with_cpu_offload()
 
     @unittest.skip(
         """
@@ -673,7 +670,7 @@ def test_torch_compile_with_group_offload_leaf(self):
 
         # For use_stream=True:
         # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+        super()._test_torch_compile_with_group_offload_leaf()
 
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
