 # limitations under the License.

 import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union, Any

 import numpy as np
 import torch

     Examples:
         ```python
         >>> import torch
-        >>> from diffusers import CogView4Pipeline
+        >>> from diffusers import CogView4ControlPipeline

         >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
         >>> control_image = load_image(

@@ -420,6 +420,14 @@ def do_classifier_free_guidance(self):
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def interrupt(self):
         return self._interrupt

@@ -446,6 +454,7 @@ def __call__(
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         output_type: str = "pil",
         return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,

@@ -559,6 +568,8 @@ def __call__(
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False

         # Default call parameters

@@ -652,6 +663,8 @@ def __call__(
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
+
+                self._current_timestep = t
                 latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype)

                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML

@@ -664,6 +677,7 @@ def __call__(
                     original_size=original_size,
                     target_size=target_size,
                     crop_coords=crops_coords_top_left,
+                    attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]

@@ -676,6 +690,7 @@ def __call__(
                         original_size=original_size,
                         target_size=target_size,
                         crop_coords=crops_coords_top_left,
+                        attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]

@@ -700,6 +715,8 @@ def __call__(
                 if XLA_AVAILABLE:
                     xm.mark_step()

+        self._current_timestep = None
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
             image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
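For reference, a minimal usage sketch of the two additions (not part of the commit itself): `attention_kwargs` is simply forwarded to the transformer on both the conditional and unconditional passes, and the new `current_timestep` property can be read back while the denoising loop runs, e.g. from a `callback_on_step_end`. The prompt, control-image URL, and the `{"scale": ...}` payload below are illustrative assumptions, and the callback signature assumes the usual diffusers step-end convention.

```python
import torch
from diffusers import CogView4ControlPipeline
from diffusers.utils import load_image

pipe = CogView4ControlPipeline.from_pretrained(
    "THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Hypothetical control image; substitute your own edge/canny map.
control_image = load_image("https://example.com/canny_edges.png")


def log_step(pipe, step, timestep, callback_kwargs):
    # `current_timestep` is set at the top of each denoising step and reset
    # to None once the loop finishes (see the diff above).
    print(f"step {step}: current_timestep={pipe.current_timestep}")
    return callback_kwargs


image = pipe(
    prompt="a cabin in the woods at golden hour",  # illustrative prompt
    control_image=control_image,
    attention_kwargs={"scale": 1.0},  # assumed payload, passed through to the attention processors
    callback_on_step_end=log_step,
).images[0]
image.save("output.png")
```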