@@ -75,6 +75,9 @@ class Base4bitTests(unittest.TestCase):
     # This was obtained on audace so the number might slightly change
     expected_rel_difference = 3.69
 
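+    # Minimum allowed ratio of fp16 memory consumption to 4-bit memory
+    # consumption, checked in test_model_memory_usage below.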
+    expected_memory_saving_ratio = 0.8
+
     prompt = "a beautiful sunset amidst the mountains."
     num_inference_steps = 10
     seed = 0
@@ -119,8 +122,11 @@ def setUp(self):
         )
 
     def tearDown(self):
-        del self.model_fp16
-        del self.model_4bit
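+        # Guard the deletes: test_model_memory_usage removes these attributes itself.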
+        if hasattr(self, "model_fp16"):
+            del self.model_fp16
+        if hasattr(self, "model_4bit"):
+            del self.model_4bit
 
         gc.collect()
         torch.cuda.empty_cache()
@@ -159,6 +165,34 @@ def test_memory_footprint(self):
         linear = get_some_linear_layer(self.model_4bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
 
+    def test_model_memory_usage(self):
+        # Delete to not let anything interfere.
+        del self.model_4bit, self.model_fp16
+
+        # Re-instantiate.
+        inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
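+        # Measure the fp16 baseline first, then free it before loading the quantized model.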
+        model_fp16 = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", torch_dtype=torch.float16
+        ).to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+        del model_fp16
+
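+        # Repeat the measurement with the NF4-quantized model and compare consumption.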
+        nf4_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        model_4bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
+        )
+        quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
     def test_original_dtype(self):
         r"""
         A simple test to check if the model successfully stores the original dtype
@@ -329,29 +363,6 @@ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
 
             assert key_to_target in str(err_context.exception)
 
-    def test_model_memory_usage(self):
-        # Delete to not let anything interfere.
-        del self.model_4bit, self.model_fp16
-
-        # Re-instantiate.
-        inputs = self.get_dummy_inputs()
-        model_fp16 = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", torch_dtype=torch.float16
-        )
-        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
-        nf4_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.float16,
-        )
-        model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
-        )
-        quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
-        print(f"{unquantized_model_memory=} {quantized_model_memory=}")
-        assert (1.0 - (unquantized_model_memory / quantized_model_memory)) >= 100.
-
-
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):