import math
import re
import urllib.parse as ul
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
-import tqdm
from transformers import T5EncoderModel, T5Tokenizer

+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro
from ...models.embeddings import get_3d_rotary_pos_embed_allegro
from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -171,6 +171,12 @@ class AllegroPipeline(DiffusionPipeline):
    _optional_components = ["tokenizer", "text_encoder", "vae", "transformer", "scheduler"]
    model_cpu_offload_seq = "text_encoder->transformer->vae"

+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+    ]
+
    def __init__(
        self,
        tokenizer: T5Tokenizer,
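# Note (not part of the diff): `_callback_tensor_inputs` whitelists the loop
# tensors a step-end callback may request. A reusable callback object can
# declare what it needs via `tensor_inputs` -- a minimal sketch, assuming the
# `PipelineCallback` interface from `diffusers.callbacks` (the class name and
# printed metric here are illustrative):
from diffusers.callbacks import PipelineCallback

class LatentNormLogger(PipelineCallback):
    # Every name listed here must appear in the pipeline's _callback_tensor_inputs.
    tensor_inputs = ["latents"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
        latents = callback_kwargs["latents"]
        print(f"step {step_index}: |latents| = {latents.norm().item():.2f}")
        return callback_kwargs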
@@ -198,7 +204,7 @@ def encode_prompt(
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool = True,
        negative_prompt: str = "",
-        num_images_per_prompt: int = 1,
+        num_videos_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -286,10 +292,10 @@ def encode_prompt(

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -320,11 +326,11 @@ def encode_prompt(

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
        else:
            negative_prompt_embeds = None
            negative_prompt_attention_mask = None
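# Note (not part of the diff): the repeat-then-view pattern above duplicates
# each prompt's embeddings `num_videos_per_prompt` times without
# `repeat_interleave`, whose backend support (e.g. on mps) has historically
# been spotty. A shape sketch with illustrative sizes:
import torch

bs_embed, seq_len, dim = 2, 300, 4096  # illustrative sizes
num_videos_per_prompt = 3
prompt_embeds = torch.randn(bs_embed, seq_len, dim)

# (b, l, d) -> (b, n*l, d) -> (b*n, l, d): copies of each prompt stay adjacent,
# matching repeat_interleave on the batch dimension.
out = prompt_embeds.repeat(1, num_videos_per_prompt, 1).view(
    bs_embed * num_videos_per_prompt, seq_len, dim
)
assert torch.equal(out, prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0))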
@@ -355,8 +361,8 @@ def check_inputs(
        num_frames,
        height,
        width,
-        negative_prompt,
-        callback_steps,
+        callback_on_step_end_tensor_inputs,
+        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
@@ -367,12 +373,11 @@ def check_inputs(
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
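# Note (not part of the diff): the new check fails fast, before any encoding or
# denoising work, when a caller requests a tensor the pipeline does not expose.
# Hypothetical call illustrating the error:
#
#   pipe(prompt="...", callback_on_step_end_tensor_inputs=["noise_pred"])
#   # ValueError: `callback_on_step_end_tensor_inputs` has to be in
#   # ['latents', 'prompt_embeds', 'negative_prompt_embeds'], but found ['noise_pred']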
@@ -606,20 +611,16 @@ def _prepare_rotary_positional_embeddings(
        num_frames: int,
        device: torch.device,
    ):
-        attention_head_dim = 96
-        vae_scale_factor_spatial = 8
-        patch_size = 2
-
-        grid_height = height // (vae_scale_factor_spatial * patch_size)
-        grid_width = width // (vae_scale_factor_spatial * patch_size)
-        base_size_width = 1280 // (vae_scale_factor_spatial * patch_size)
-        base_size_height = 720 // (vae_scale_factor_spatial * patch_size)
+        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_width = 1280 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_height = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro(
-            embed_dim=attention_head_dim,
+            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
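# Note (not part of the diff): reading `vae_scale_factor_spatial` and
# `patch_size` from the loaded model's config keeps the RoPE grid correct for
# non-default checkpoints. With the values the removed constants assumed
# (scale factor 8, patch size 2), a 720x1280 frame maps to a 45x80 patch grid:
vae_scale_factor_spatial = 8  # assumed, per the removed hard-coded constants
patch_size = 2                # assumed, per the removed hard-coded constants

grid_height = 720 // (vae_scale_factor_spatial * patch_size)   # 45
grid_width = 1280 // (vae_scale_factor_spatial * patch_size)   # 80
assert (grid_height, grid_width) == (45, 80)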
@@ -653,10 +654,10 @@ def __call__(
        num_inference_steps: int = 100,
        timesteps: List[int] = None,
        guidance_scale: float = 7.5,
-        num_images_per_prompt: Optional[int] = 1,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
+        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
@@ -666,11 +667,12 @@ def __call__(
        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        clean_caption: bool = True,
        max_sequence_length: int = 300,
-        verbose: bool = True,
    ) -> Union[AllegroPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.
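# Note (not part of the diff): after this change, callers pass a per-step
# callback instead of the removed `callback`/`callback_steps` pair. A minimal
# usage sketch (checkpoint id, prompt, and output handling are illustrative):
import torch
from diffusers import AllegroPipeline

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16)

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # Runs after every scheduler step; returned tensors overwrite the loop's.
    print(f"step {step}, t={timestep}")
    return callback_kwargs

video = pipe(
    prompt="a sailboat drifting at sunset",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).frames[0]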
@@ -746,6 +748,12 @@ def __call__(
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        num_videos_per_prompt = 1
+
        # 1. Check inputs. Raise error if not correct
        num_frames = num_frames or self.transformer.config.sample_size_t * self.vae_scale_factor_temporal
        height = height or self.transformer.config.sample_size[0] * self.vae_scale_factor_spatial
@@ -756,13 +764,15 @@ def __call__(
            num_frames,
            height,
            width,
+            callback_on_step_end_tensor_inputs,
            negative_prompt,
-            callback_steps,
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
        )
+        self._guidance_scale = guidance_scale
+        self._interrupt = False

        # 2. Default height and width to transformer
        if prompt is not None and isinstance(prompt, str):
@@ -789,7 +799,7 @@ def __call__(
            prompt,
            do_classifier_free_guidance,
            negative_prompt=negative_prompt,
-            num_images_per_prompt=num_images_per_prompt,
+            num_videos_per_prompt=num_videos_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
@@ -809,7 +819,7 @@ def __call__(
        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
-            batch_size * num_images_per_prompt,
+            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
@@ -831,45 +841,56 @@ def __call__(
        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

-        progress_wrap = tqdm.tqdm if verbose else (lambda x: x)
-        for i, t in progress_wrap(list(enumerate(timesteps))):
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-            timestep = t.expand(latent_model_input.shape[0])
-
-            if prompt_embeds.ndim == 3:
-                prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
-
-            # prepare attention_mask.
-            # b c t h w -> b t h w
-            attention_mask = torch.ones_like(latent_model_input)[:, 0]
-
-            # predict noise model_output
-            noise_pred = self.transformer(
-                latent_model_input,
-                attention_mask=attention_mask,
-                encoder_hidden_states=prompt_embeds,
-                encoder_attention_mask=prompt_attention_mask,
-                timestep=timestep,
-                image_rotary_emb=image_rotary_emb,
-                return_dict=False,
-            )[0]
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute previous image: x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-            # call the callback, if provided
-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                if callback is not None and i % callback_steps == 0:
-                    step_idx = i // getattr(self.scheduler, "order", 1)
-                    callback(step_idx, t, latents)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latent_model_input.shape[0])
+
+                if prompt_embeds.ndim == 3:
+                    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
+
+                # prepare attention_mask.
+                # b c t h w -> b t h w
+                attention_mask = torch.ones_like(latent_model_input)[:, 0]
+
+                # predict noise model_output
+                noise_pred = self.transformer(
+                    latent_model_input,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=prompt_embeds,
+                    encoder_attention_mask=prompt_attention_mask,
+                    timestep=timestep,
+                    image_rotary_emb=image_rotary_emb,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute previous image: x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                # call the callback, if provided
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()

        if not output_type == "latent":
            latents = latents.to(self.vae.dtype)
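# Note (not part of the diff): because the loop hands the requested tensors to
# the callback and takes back whatever it returns, a callback can observe or
# edit state mid-run. A sketch that snapshots latents every 10 steps for later
# inspection (function and variable names are illustrative; `pipe` is the
# pipeline from the earlier sketch):
intermediate_latents = []

def snapshot_latents(pipeline, step, timestep, callback_kwargs):
    if step % 10 == 0:
        intermediate_latents.append(callback_kwargs["latents"].detach().cpu().clone())
    return callback_kwargs  # tensors returned here replace the loop's locals

video = pipe(
    prompt="a koala playing piano",
    callback_on_step_end=snapshot_latents,
    callback_on_step_end_tensor_inputs=["latents"],
)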