
Commit ea6202b (parent: 844221a)

improve replacement warnings for bnb

3 files changed: 37 additions, 5 deletions

src/diffusers/quantizers/bitsandbytes/utils.py (9 additions, 5 deletions)
```diff
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     models by reducing the precision of the weights and activations, thus making models more efficient in terms
     of both storage and computation.
     """
-    model, has_been_replaced = _replace_with_bnb_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config
-    )
+    model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
 
+    has_been_replaced = any(
+        isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+        for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -283,13 +285,15 @@ def dequantize_and_replace(
     modules_to_not_convert=None,
     quantization_config=None,
 ):
-    model, has_been_replaced = _dequantize_and_replace(
+    model, _ = _dequantize_and_replace(
         model,
         dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
-
+    has_been_replaced = any(
+        isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
+    )
     if not has_been_replaced:
         logger.warning(
             "For some reason the model has not been properly dequantized. You might see unexpected behavior."
```

tests/quantization/bnb/test_4bit.py (14 additions, 0 deletions)
```diff
@@ -63,6 +63,8 @@ def get_some_linear_layer(model):
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -364,6 +366,18 @@ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
 
         assert key_to_target in str(err_context.exception)
 
+    def test_bnb_4bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
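```
The test uses diffusers' `CaptureLogger` helper together with its `logging.get_logger` wrapper (presumably imported from `diffusers.utils`); `logger.setLevel(30)` sets the numeric threshold for `logging.WARNING`. For readers unfamiliar with the helper, here is a rough stdlib-only approximation of what such a capture does; this is a sketch, not the actual `CaptureLogger` implementation:

```python
import io
import logging

logger = logging.getLogger("diffusers.quantizers.bitsandbytes.utils")
logger.setLevel(logging.WARNING)  # WARNING == 30, the numeric value in the test

# Attach a temporary handler that writes log records into an in-memory buffer.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
logger.addHandler(handler)
try:
    logger.warning("no linear modules were found in your model.")
finally:
    logger.removeHandler(handler)

assert "no linear modules" in stream.getvalue()
```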

tests/quantization/bnb/test_mixed_int8.py (14 additions, 0 deletions)
```diff
@@ -68,6 +68,8 @@ def get_some_linear_layer(model):
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
 
+    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
+
 
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
@@ -314,6 +316,18 @@ def test_device_and_dtype_assignment(self):
         # Check that this does not throw an error
         _ = self.model_fp16.cuda()
 
+    def test_bnb_8bit_logs_warning_for_no_quantization(self):
+        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+        assert (
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            in cap_logger.out
+        )
+
 
 class Bnb8bitDeviceTests(Base8bitTests):
     def setUp(self) -> None:
```
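The 8-bit test mirrors the 4-bit one, differing only in `load_in_8bit=True` and in importing `replace_with_bnb_linear` from the package rather than its `utils` submodule. In the positive case, where replacement does succeed, the new detection in `utils.py` finds bitsandbytes layer types in the tree; a sketch of that check, assuming bitsandbytes is installed (layer sizes are arbitrary):

```python
import bitsandbytes as bnb
import torch

# Build a model that does contain a bitsandbytes linear layer.
model = torch.nn.Sequential(bnb.nn.Linear8bitLt(16, 16, has_fp16_weights=False))

# The same scan as the new has_been_replaced logic in utils.py.
replaced = any(
    isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
    for _, module in model.named_modules()
)
print(replaced)  # True, so no warning would be logged
```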
