 # See the License for the specific language governing permissions and
 # limitations under the License.
 import gc
+import os
 import tempfile
 import unittest
 
 import numpy as np
+import safetensors.torch
 
 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
 from diffusers.utils import logging
@@ -118,6 +120,9 @@ def get_dummy_inputs(self):
 
 class BnB4BitBasicTests(Base4bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         # Models
         self.model_fp16 = SD3Transformer2DModel.from_pretrained(
             self.model_name, subfolder="transformer", torch_dtype=torch.float16
@@ -232,7 +237,7 @@ def test_linear_are_4bit(self):
 
     def test_config_from_pretrained(self):
         transformer_4bit = FluxTransformer2DModel.from_pretrained(
-            "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
+            "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
         )
         linear = get_some_linear_layer(transformer_4bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
@@ -312,9 +317,42 @@ def test_bnb_4bit_wrong_config(self):
         with self.assertRaises(ValueError):
             _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")
 
+    def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
+        r"""
+        Test if loading with an incorrect state dict raises an error.
+        """
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            nf4_config = BitsAndBytesConfig(load_in_4bit=True)
+            model_4bit = SD3Transformer2DModel.from_pretrained(
+                self.model_name, subfolder="transformer", quantization_config=nf4_config
+            )
+            model_4bit.save_pretrained(tmpdirname)
+            del model_4bit
+
+            with self.assertRaises(ValueError) as err_context:
+                state_dict = safetensors.torch.load_file(
+                    os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
+                )
+
+                # corrupt the state dict
+                key_to_target = "context_embedder.weight"  # can be other keys too.
+                compatible_param = state_dict[key_to_target]
+                corrupted_param = torch.randn(compatible_param.shape[0] - 1, 1)
+                state_dict[key_to_target] = bnb.nn.Params4bit(corrupted_param, requires_grad=False)
+                safetensors.torch.save_file(
+                    state_dict, os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
+                )
+
+                _ = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+            assert key_to_target in str(err_context.exception)
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
@@ -360,6 +398,9 @@ def test_training(self):
 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitTests(Base4bitTests):
     def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
         nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
@@ -447,8 +488,10 @@ def test_moving_to_cpu_throws_warning(self):
 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitFluxTests(Base4bitTests):
     def setUp(self) -> None:
-        # TODO: Copy sayakpaul/flux.1-dev-nf4-pkg to testing repo.
-        model_id = "sayakpaul/flux.1-dev-nf4-pkg"
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
         t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
         transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
         self.pipeline_4bit = DiffusionPipeline.from_pretrained(
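
For readers skimming the diff, the sketch below restates the scenario the new test_bnb_4bit_errors_loading_incorrect_state_dict exercises as a standalone script, using only APIs that appear in the diff above. MODEL_ID is an assumption standing in for Base4bitTests.model_name, whose value is defined outside this diff, and the try/except simply mirrors the assertRaises(ValueError) block rather than pinning down exactly where the error is raised.

import os
import tempfile

import bitsandbytes as bnb
import safetensors.torch
import torch

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Assumption: stand-in for `Base4bitTests.model_name`, which is not shown in this diff.
MODEL_ID = "stabilityai/stable-diffusion-3-medium-diffusers"

with tempfile.TemporaryDirectory() as tmpdirname:
    # Quantize the SD3 transformer to 4-bit and serialize it.
    nf4_config = BitsAndBytesConfig(load_in_4bit=True)
    model_4bit = SD3Transformer2DModel.from_pretrained(
        MODEL_ID, subfolder="transformer", quantization_config=nf4_config
    )
    model_4bit.save_pretrained(tmpdirname)
    del model_4bit

    ckpt_path = os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")

    try:
        state_dict = safetensors.torch.load_file(ckpt_path)

        # Corrupt one entry with a tensor whose shape no longer matches the checkpoint.
        key_to_target = "context_embedder.weight"
        corrupted = torch.randn(state_dict[key_to_target].shape[0] - 1, 1)
        state_dict[key_to_target] = bnb.nn.Params4bit(corrupted, requires_grad=False)
        safetensors.torch.save_file(state_dict, ckpt_path)

        # Reloading the tampered checkpoint is expected to fail.
        SD3Transformer2DModel.from_pretrained(tmpdirname)
    except ValueError as err:
        # As in the test, the only guarantee checked is that the corrupted key
        # is named in the error message.
        assert key_to_target in str(err)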