Skip to content

Commit edf66b7

Browse files
committed
make common utility.
1 parent 0e2f5b4 commit edf66b7

File tree

3 files changed

+74
-28
lines changed

3 files changed

+74
-28
lines changed

tests/quantization/bnb/test_4bit.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,14 @@
4545
require_peft_backend,
4646
require_torch,
4747
require_torch_accelerator,
48-
require_torch_gpu,
4948
require_torch_version_greater,
5049
require_transformers_version_greater,
5150
slow,
5251
torch_device,
5352
)
5453

54+
from ..utils import QuantCompileMiscTests
55+
5556

5657
def get_some_linear_layer(model):
5758
if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -860,23 +861,9 @@ def test_fp4_double_safe(self):
860861
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
861862

862863

863-
@require_torch_gpu
864-
@slow
865-
class Bnb4BitCompileTests(unittest.TestCase):
866-
def setUp(self):
867-
super().setUp()
868-
gc.collect()
869-
backend_empty_cache(torch_device)
870-
torch.compiler.reset()
871-
872-
def tearDown(self):
873-
super().tearDown()
874-
gc.collect()
875-
backend_empty_cache(torch_device)
876-
torch.compiler.reset()
877-
864+
class Bnb4BitCompileTests(QuantCompileMiscTests):
878865
@require_torch_version_greater("2.7.1")
879-
def test_torch_compile_4bit(self):
866+
def test_torch_compile(self):
880867
torch._dynamo.config.capture_dynamic_output_shape_ops = True
881868

882869
quantization_config = PipelineQuantizationConfig(
@@ -886,15 +873,6 @@ def test_torch_compile_4bit(self):
886873
"bnb_4bit_quant_type": "nf4",
887874
"bnb_4bit_compute_dtype": torch.bfloat16,
888875
},
889-
components_to_quantize=["transformer"],
876+
components_to_quantize=["transformer", "text_encoder_2"],
890877
)
891-
pipe = DiffusionPipeline.from_pretrained(
892-
"stabilityai/stable-diffusion-3-medium-diffusers",
893-
quantization_config=quantization_config,
894-
torch_dtype=torch.bfloat16,
895-
).to("cuda")
896-
pipe.transformer.compile(fullgraph=True)
897-
898-
for _ in range(2):
899-
# with torch._dynamo.config.patch(error_on_recompile=True):
900-
pipe("a dog", num_inference_steps=4, max_sequence_length=16, height=256, width=256)
878+
super().test_torch_compile(quantization_config=quantization_config)

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SD3Transformer2DModel,
2929
logging,
3030
)
31+
from diffusers.quantizers import PipelineQuantizationConfig
3132
from diffusers.utils import is_accelerate_version
3233
from diffusers.utils.testing_utils import (
3334
CaptureLogger,
@@ -42,11 +43,14 @@
4243
require_peft_version_greater,
4344
require_torch,
4445
require_torch_accelerator,
46+
require_torch_version_greater_equal,
4547
require_transformers_version_greater,
4648
slow,
4749
torch_device,
4850
)
4951

52+
from ..utils import QuantCompileMiscTests
53+
5054

5155
def get_some_linear_layer(model):
5256
if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -773,3 +777,18 @@ def test_serialization_sharded(self):
773777
out_0 = self.model_0(**inputs)[0]
774778
out_1 = model_1(**inputs)[0]
775779
self.assertTrue(torch.equal(out_0, out_1))
780+
781+
782+
class Bnb8BitCompileTests(QuantCompileMiscTests):
    """torch.compile smoke test for a pipeline quantized with the ``bitsandbytes_8bit`` backend."""

    @require_torch_version_greater_equal("2.6.0")
    def test_torch_compile(self):
        """Run the shared compile test with ``transformer`` and ``text_encoder_2`` loaded in 8-bit."""
        quantization_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
            components_to_quantize=["transformer", "text_encoder_2"],
        )

        # Scope the dynamo flag to this test only: a bare module-level assignment
        # would stay in effect for every test that runs later in the same process.
        with torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True):
            super().test_torch_compile(quantization_config=quantization_config, torch_dtype=torch.float16)
tests/quantization/utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# coding=utf-8
2+
# Copyright 2024 The HuggingFace Team Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import gc
16+
import unittest
17+
18+
import torch
19+
20+
from diffusers import DiffusionPipeline
21+
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device
22+
23+
24+
@require_torch_gpu
@slow
class QuantCompileMiscTests(unittest.TestCase):
    """Shared torch.compile smoke test for quantized diffusion pipelines.

    Backend-specific subclasses override ``test_torch_compile`` with a
    zero-argument method that builds a quantization config and forwards it
    with ``super().test_torch_compile(quantization_config=...)``.
    """

    def _reset_state(self):
        """Collect garbage, empty the accelerator cache and reset torch.compile state."""
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def setUp(self):
        super().setUp()
        self._reset_state()

    def tearDown(self):
        super().tearDown()
        self._reset_state()

    def test_torch_compile(self, quantization_config=None, torch_dtype=torch.bfloat16):
        """Load a quantized pipeline, compile its transformer and run it twice.

        Args:
            quantization_config: Quantization config passed to
                ``DiffusionPipeline.from_pretrained``. When ``None`` the test
                is skipped.
            torch_dtype: dtype the pipeline is loaded with.
        """
        # This class is a TestCase itself, so a test runner may invoke this
        # method directly with no arguments; skip instead of raising TypeError.
        if quantization_config is None:
            self.skipTest("Base compile test needs a quantization_config supplied by a subclass.")

        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers",
            quantization_config=quantization_config,
            torch_dtype=torch_dtype,
        ).to(torch_device)  # same device the cache helpers in _reset_state target
        pipe.transformer.compile(fullgraph=True)

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=4, max_sequence_length=16, height=256, width=256)

0 commit comments

Comments
 (0)