Commit a1b70d1

fixes
1 parent a81d3b8 commit a1b70d1

7 files changed: +84 -57 lines changed


src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py

Lines changed: 2 additions & 0 deletions
@@ -135,6 +135,7 @@ def create_quantized_param(
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         import bitsandbytes as bnb

@@ -445,6 +446,7 @@ def create_quantized_param(
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         import bitsandbytes as bnb

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 1 addition & 0 deletions
@@ -215,6 +215,7 @@ def create_quantized_param(
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: List[str],
+        **kwargs,
     ):
         r"""
         Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor,
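For context: accepting **kwargs in create_quantized_param lets the loading code pass newer, backend-specific keyword arguments without breaking quantizer backends that do not use them. A minimal sketch of the pattern, using hypothetical names rather than the actual diffusers call sites:

    # Hypothetical illustration of the **kwargs-tolerant signature; not the diffusers API.
    class DummyQuantizer:
        def create_quantized_param(self, model, param_value, param_name, target_device, **kwargs):
            # Backends that need an extra argument can read it; others simply ignore it.
            casting_dtype = kwargs.get("casting_dtype")
            return param_value  # a real backend would return a quantized tensor here

    # The caller can now pass new keyword arguments uniformly to every backend:
    DummyQuantizer().create_quantized_param(None, 0.0, "weight", "cpu", casting_dtype="float16")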

tests/quantization/bnb/test_4bit.py

Lines changed: 32 additions & 25 deletions
@@ -75,6 +75,8 @@ class Base4bitTests(unittest.TestCase):
     # This was obtained on audace so the number might slightly change
     expected_rel_difference = 3.69

+    expected_memory_saving_ratio = 0.8
+
     prompt = "a beautiful sunset amidst the mountains."
     num_inference_steps = 10
     seed = 0

@@ -119,8 +121,10 @@ def setUp(self):
         )

     def tearDown(self):
-        del self.model_fp16
-        del self.model_4bit
+        if hasattr(self, "model_fp16"):
+            del self.model_fp16
+        if hasattr(self, "model_4bit"):
+            del self.model_4bit

         gc.collect()
         torch.cuda.empty_cache()

@@ -159,6 +163,32 @@ def test_memory_footprint(self):
         linear = get_some_linear_layer(self.model_4bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)

+    def test_model_memory_usage(self):
+        # Delete to not let anything interfere.
+        del self.model_4bit, self.model_fp16
+
+        # Re-instantiate.
+        inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
+        model_fp16 = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", torch_dtype=torch.float16
+        ).to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+        del model_fp16
+
+        nf4_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        model_4bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
+        )
+        quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
     def test_original_dtype(self):
         r"""
         A simple test to check if the model succesfully stores the original dtype

@@ -329,29 +359,6 @@ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):

         assert key_to_target in str(err_context.exception)

-    def test_model_memory_usage(self):
-        # Delete to not let anything interfere.
-        del self.model_4bit, self.model_fp16
-
-        # Re-instantiate.
-        inputs = self.get_dummy_inputs()
-        model_fp16 = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", torch_dtype=torch.float16
-        )
-        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
-        nf4_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.float16,
-        )
-        model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
-        )
-        quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
-        print(f"{unquantized_model_memory=}, {quantized_model_memory=}")
-        assert (1.0 - (unquantized_model_memory / quantized_model_memory)) >= 100.
-
-

 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
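The replaced assertion makes the check meaningful: the old form, (1.0 - unquantized/quantized) >= 100., cannot hold for positive memory readings because its left-hand side is always below 1.0, whereas the new form only requires the peak-memory ratio to clear the class-level floor. A quick illustration with made-up numbers (not measurements from this test):

    # Made-up peak-memory figures, purely to illustrate the new check.
    unquantized_model_memory = 10.0  # GB for the fp16 transformer (hypothetical)
    quantized_model_memory = 4.0     # GB for the NF4 transformer (hypothetical)
    assert unquantized_model_memory / quantized_model_memory >= 0.8  # 2.5 >= 0.8, so the check passes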

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 28 additions & 20 deletions
@@ -81,6 +81,8 @@ class Base8bitTests(unittest.TestCase):
     # This was obtained on audace so the number might slightly change
     expected_rel_difference = 1.94

+    expected_memory_saving_ratio = 0.7
+
     prompt = "a beautiful sunset amidst the mountains."
     num_inference_steps = 10
     seed = 0

@@ -121,8 +123,10 @@ def setUp(self):
         )

     def tearDown(self):
-        del self.model_fp16
-        del self.model_8bit
+        if hasattr(self, "model_fp16"):
+            del self.model_fp16
+        if hasattr(self, "model_8bit"):
+            del self.model_8bit

         gc.collect()
         torch.cuda.empty_cache()

@@ -161,6 +165,28 @@ def test_memory_footprint(self):
         linear = get_some_linear_layer(self.model_8bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)

+    def test_model_memory_usage(self):
+        # Delete to not let anything interfere.
+        del self.model_8bit, self.model_fp16
+
+        # Re-instantiate.
+        inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
+        model_fp16 = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", torch_dtype=torch.float16
+        ).to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+        del model_fp16
+
+        config = BitsAndBytesConfig(load_in_8bit=True)
+        model_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16
+        )
+        quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
     def test_original_dtype(self):
         r"""
         A simple test to check if the model succesfully stores the original dtype

@@ -287,24 +313,6 @@ def test_device_and_dtype_assignment(self):
         # Check that this does not throw an error
         _ = self.model_fp16.cuda()

-    def test_model_memory_usage(self):
-        # Delete to not let anything interfere.
-        del self.model_4bit, self.model_fp16
-
-        # Re-instantiate.
-        inputs = self.get_dummy_inputs()
-        model_fp16 = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", torch_dtype=torch.float16
-        )
-        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
-        config = BitsAndBytesConfig(load_in_8bit=True)
-        model_8bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=config, device_map=torch_device
-        )
-        quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs)
-        print(f"{unquantized_model_memory=}, {quantized_model_memory=}")
-        assert (1.0 - (unquantized_model_memory / quantized_model_memory)) >= 100.
-

 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:

tests/quantization/quanto/test_quanto.py

Lines changed: 11 additions & 6 deletions
@@ -19,7 +19,7 @@

 if is_torch_available():
     import torch
-
+
 from ..utils import LoRALayer, get_memory_consumption_stat

@@ -64,15 +64,20 @@ def test_quanto_layers(self):
             assert isinstance(module, QLinear)

     def test_quanto_memory_usage(self):
-        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
         inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
+
+        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
+        unquantized_model.to(torch_device)
         unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)

         quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+        quantized_model.to(torch_device)
         quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)

-        print(f"{unquantized_model_memory=}, {quantized_model_memory=}")
-        assert (1.0 - (unquantized_model_memory / quantized_model_memory)) >= self.expected_memory_reduction
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction

     def test_keep_modules_in_fp32(self):
         r"""

@@ -292,14 +297,14 @@ def test_training(self):


 class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
-    expected_memory_reduction = 0.3
+    expected_memory_reduction = 0.6

     def get_dummy_init_kwargs(self):
         return {"weights_dtype": "float8"}


 class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
-    expected_memory_reduction = 0.3
+    expected_memory_reduction = 0.6
     _test_torch_compile = True

     def get_dummy_init_kwargs(self):
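With the new assertion, expected_memory_reduction acts as a lower bound on the ratio unquantized_model_memory / quantized_model_memory rather than on a fractional reduction, which is consistent with the class-level values moving from 0.3 to 0.6; under the old form the left-hand side is negative whenever quantization actually saves memory, so it could never clear a 0.3 threshold.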

tests/quantization/torchao/test_torchao.py

Lines changed: 9 additions & 4 deletions
@@ -483,16 +483,21 @@ def test_memory_footprint(self):
         # there is additional overhead of scales and zero points
         self.assertTrue(total_bf16 < total_int4wo)

-    def test_memory_usage(self):
+    def test_model_memory_usage(self):
         model_id = "hf-internal-testing/tiny-flux-pipe"
-        inputs = self.get_dummy_inputs()
+        expected_memory_saving_ratio = 2.0
+
+        inputs = self.get_dummy_tensor_inputs(device=torch_device)
+
         transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
+        transformer_bf16.to(torch_device)
         unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
+        del transformer_bf16

         transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
+        transformer_int8wo.to(torch_device)
         quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
-        print(f"{unquantized_model_memory=}, {quantized_model_memory=}")
-        assert (1.0 - (unquantized_model_memory / quantized_model_memory)) >= 100.
+        assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio

     def test_wrong_config(self):
         with self.assertRaises(ValueError):
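The 2.0 floor here lines up with int8 weight-only quantization in rough terms: bf16 stores weights at 2 bytes per parameter versus 1 byte for int8, so for a small, weight-dominated model the peak-memory ratio should approach 2 / 1 = 2.0 (a back-of-the-envelope expectation, not a measured figure).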

tests/quantization/utils.py

Lines changed: 1 addition & 2 deletions
@@ -26,8 +26,7 @@ def __init__(self, module: nn.Module, rank: int):

     def forward(self, input, *args, **kwargs):
         return self.module(input, *args, **kwargs) + self.adapter(input)
-
-
+
 @torch.no_grad()
 @torch.inference_mode()
 def get_memory_consumption_stat(model, inputs):
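The body of get_memory_consumption_stat is not part of this diff; as a rough sketch only (an assumption, not the actual diffusers helper), a peak-memory probe of this shape typically resets the CUDA peak counter, runs one forward pass, and reports the high-water mark:

    import torch

    @torch.no_grad()
    @torch.inference_mode()
    def get_memory_consumption_stat(model, inputs):
        # Reset the CUDA peak-memory counter, run a single forward pass,
        # and return the recorded high-water mark in bytes.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        model(**inputs)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()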
