@@ -136,7 +136,7 @@ def setUp(self):
             bnb_4bit_compute_dtype=torch.float16,
         )
         self.model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
 
     def tearDown(self):
@@ -202,7 +202,7 @@ def test_keep_modules_in_fp32(self):
             bnb_4bit_compute_dtype=torch.float16,
         )
         model = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
 
         for name, module in model.named_modules():
@@ -327,7 +327,7 @@ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             nf4_config = BitsAndBytesConfig(load_in_4bit=True)
             model_4bit = SD3Transformer2DModel.from_pretrained(
-                self.model_name, subfolder="transformer", quantization_config=nf4_config
+                self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
             )
             model_4bit.save_pretrained(tmpdirname)
             del model_4bit
@@ -362,7 +362,7 @@ def setUp(self):
             bnb_4bit_compute_dtype=torch.float16,
         )
         self.model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
 
     def test_training(self):
@@ -410,7 +410,7 @@ def setUp(self) -> None:
             bnb_4bit_compute_dtype=torch.float16,
         )
         model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
         self.pipeline_4bit = DiffusionPipeline.from_pretrained(
             self.model_name, transformer=model_4bit, torch_dtype=torch.float16
@@ -472,7 +472,7 @@ def test_moving_to_cpu_throws_warning(self):
             bnb_4bit_compute_dtype=torch.float16,
         )
         model_4bit = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=nf4_config
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
         )
 
         logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
@@ -502,6 +502,7 @@ def test_pipeline_cuda_placement_works_with_nf4(self):
             subfolder="transformer",
             quantization_config=transformer_nf4_config,
             torch_dtype=torch.float16,
+            device_map=torch_device,
         )
         text_encoder_3_nf4_config = BnbConfig(
             load_in_4bit=True,
@@ -513,6 +514,7 @@ def test_pipeline_cuda_placement_works_with_nf4(self):
             subfolder="text_encoder_3",
             quantization_config=text_encoder_3_nf4_config,
             torch_dtype=torch.float16,
+            device_map=torch_device,
         )
         # CUDA device placement works.
         pipeline_4bit = DiffusionPipeline.from_pretrained(
@@ -527,6 +529,101 @@ def test_pipeline_cuda_placement_works_with_nf4(self):
 
         del pipeline_4bit
 
+    def test_device_map(self):
+        """
+        Test that the quantized model works properly with device_map="auto".
+        Note that CPU/disk offloading doesn't work with bnb.
+        """
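+        # "auto" lets accelerate infer the device placement; with bitsandbytes,
+        # every quantized module has to land on an accelerator (no CPU/disk offload).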
+
+        def get_dummy_tensor_inputs(device=None, seed: int = 0):
+            batch_size = 1
+            num_latent_channels = 4
+            num_image_channels = 3
+            height = width = 4
+            sequence_length = 48
+            embedding_dim = 32
+
+            torch.manual_seed(seed)
+            hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
+                device, dtype=torch.bfloat16
+            )
+            torch.manual_seed(seed)
+            encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
+                device, dtype=torch.bfloat16
+            )
+
+            torch.manual_seed(seed)
+            pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+
+            torch.manual_seed(seed)
+            text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+            torch.manual_seed(seed)
+            image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+            timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
+
+            return {
+                "hidden_states": hidden_states,
+                "encoder_hidden_states": encoder_hidden_states,
+                "pooled_projections": pooled_prompt_embeds,
+                "txt_ids": text_ids,
+                "img_ids": image_ids,
+                "timestep": timestep,
+            }
+
+        inputs = get_dummy_tensor_inputs(torch_device)
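+        # Reference slice of the last nine output elements; both the non-sharded
+        # and the sharded model below are checked against it.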
+        expected_slice = np.array(
+            [0.47070312, 0.00390625, -0.03662109, -0.19628906, -0.53125, 0.5234375, -0.17089844, -0.59375, 0.578125]
+        )
+
+        # non-sharded checkpoint
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
+        )
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-pipe",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+
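+        # bnb stores 4-bit quantized weights as Params4bit, so this check confirms
+        # the model was actually quantized during the auto-dispatched load.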
+        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+        self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
+
+        output = quantized_model(**inputs)[0]
+        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+        # sharded checkpoint
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
+        )
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-sharded",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+
+        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+        self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
+
+        output = quantized_model(**inputs)[0]
+        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
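+        # The sharded checkpoint must reproduce the same reference slice as the
+        # non-sharded model above.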
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
 
 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitFluxTests(Base4bitTests):
@@ -610,7 +707,10 @@ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
             bnb_4bit_compute_dtype=torch.bfloat16,
         )
         model_0 = SD3Transformer2DModel.from_pretrained(
-            self.model_name, subfolder="transformer", quantization_config=self.quantization_config
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=self.quantization_config,
+            device_map=torch_device,
         )
         self.assertTrue("_pre_quantization_dtype" in model_0.config)
         with tempfile.TemporaryDirectory() as tmpdirname: