
Commit 11cfd6c

modularize.
1 parent edf66b7 commit 11cfd6c

File tree: 3 files changed, 49 additions and 23 deletions

  tests/quantization/bnb/test_4bit.py
  tests/quantization/bnb/test_mixed_int8.py
  tests/quantization/test_torch_compile_utils.py

tests/quantization/bnb/test_4bit.py

Lines changed: 17 additions & 11 deletions
@@ -51,7 +51,7 @@
     torch_device,
 )
 
-from ..utils import QuantCompileMiscTests
+from ..test_torch_compile_utils import QuantCompileMiscTests
 
 
 def get_some_linear_layer(model):
@@ -861,18 +861,24 @@ def test_fp4_double_safe(self):
         self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
 
 
+@require_torch_version_greater("2.7.1")
 class Bnb4BitCompileTests(QuantCompileMiscTests):
-    @require_torch_version_greater("2.7.1")
+    quantization_config = PipelineQuantizationConfig(
+        quant_backend="bitsandbytes_4bit",
+        quant_kwargs={
+            "load_in_4bit": True,
+            "bnb_4bit_quant_type": "nf4",
+            "bnb_4bit_compute_dtype": torch.bfloat16,
+        },
+        components_to_quantize=["transformer", "text_encoder_2"],
+    )
+
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        super()._test_torch_compile(quantization_config=self.quantization_config)
 
-        quantization_config = PipelineQuantizationConfig(
-            quant_backend="bitsandbytes_4bit",
-            quant_kwargs={
-                "load_in_4bit": True,
-                "bnb_4bit_quant_type": "nf4",
-                "bnb_4bit_compute_dtype": torch.bfloat16,
-            },
-            components_to_quantize=["transformer", "text_encoder_2"],
+    def test_torch_compile_with_cpu_offload(self):
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        super()._test_torch_compile_with_cpu_offload(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16
         )
-        super().test_torch_compile(quantization_config=quantization_config)
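Assembled from the added lines above, the 4-bit compile suite now reads roughly as follows (a reconstruction for readability, not a verbatim copy of the file; surrounding imports omitted):

@require_torch_version_greater("2.7.1")
class Bnb4BitCompileTests(QuantCompileMiscTests):
    quantization_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer", "text_encoder_2"],
    )

    def test_torch_compile(self):
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        super()._test_torch_compile_with_cpu_offload(
            quantization_config=self.quantization_config, torch_dtype=torch.float16
        )

The version gate and the quantization config now live on the class, so both tests share them and each method body shrinks to a one-line delegation into the base class.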

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 13 additions & 9 deletions
@@ -49,7 +49,7 @@
     torch_device,
 )
 
-from ..utils import QuantCompileMiscTests
+from ..test_torch_compile_utils import QuantCompileMiscTests
 
 
 def get_some_linear_layer(model):
@@ -779,16 +779,20 @@ def test_serialization_sharded(self):
         self.assertTrue(torch.equal(out_0, out_1))
 
 
+@require_torch_version_greater_equal("2.6.0")
 class Bnb8BitCompileTests(QuantCompileMiscTests):
-    @require_torch_version_greater_equal("2.6.0")
+    quantization_config = PipelineQuantizationConfig(
+        quant_backend="bitsandbytes_8bit",
+        quant_kwargs={"load_in_8bit": True},
+        components_to_quantize=["transformer", "text_encoder_2"],
+    )
+
    def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
 
-        quantization_config = PipelineQuantizationConfig(
-            quant_backend="bitsandbytes_8bit",
-            quant_kwargs={
-                "load_in_8bit": True,
-            },
-            components_to_quantize=["transformer", "text_encoder_2"],
+    def test_torch_compile_with_cpu_offload(self):
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        super()._test_torch_compile_with_cpu_offload(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16
         )
-        super().test_torch_compile(quantization_config=quantization_config, torch_dtype=torch.float16)
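Note that in both files the version gate moved from the test method to the class. With unittest, a skip decorator applied to a class skips every test it defines, which matters now that each class carries two compile tests. A minimal sketch of that behavior (names here are illustrative, not from this diff):

import unittest

# Stand-in for require_torch_version_greater_equal; the real decorator
# checks the installed torch version rather than taking a boolean.
def require_feature(available):
    return unittest.skipUnless(available, "required feature unavailable")

@require_feature(False)
class GatedTests(unittest.TestCase):
    def test_a(self):  # skipped
        pass

    def test_b(self):  # also skipped; a method-level gate would cover only one test
        pass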

tests/quantization/test_torch_compile_utils.py

Lines changed: 19 additions & 3 deletions
@@ -24,6 +24,8 @@
 @require_torch_gpu
 @slow
 class QuantCompileMiscTests(unittest.TestCase):
+    quantization_config = None
+
     def setUp(self):
         super().setUp()
         gc.collect()
@@ -36,14 +38,28 @@ def tearDown(self):
         backend_empty_cache(torch_device)
         torch.compiler.reset()
 
-    def test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
+    def _init_pipeline(self, quantization_config, torch_dtype):
         pipe = DiffusionPipeline.from_pretrained(
             "stabilityai/stable-diffusion-3-medium-diffusers",
             quantization_config=quantization_config,
             torch_dtype=torch_dtype,
-        ).to("cuda")
+        )
+        return pipe
+
+    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
+        # important to ensure fullgraph=True
         pipe.transformer.compile(fullgraph=True)
 
         for _ in range(2):
             # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=4, max_sequence_length=16, height=256, width=256)
+            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+
+    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(quantization_config, torch_dtype)
+        pipe.enable_model_cpu_offload()
+        pipe.transformer.compile()
+
+        for _ in range(2):
+            # small resolutions to ensure speedy execution.
+            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
