# limitations under the License.

import html
-
import types
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Union

import ftfy
import regex as re
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import WanLoraLoaderMixin
from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
- from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput


if is_torch_xla_available():

        >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
        >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
-
+
        >>> # Configure STG mode options
        >>> stg_applied_layers_idx = [8]  # Layer indices from 0 to 39 for 14b or 0 to 29 for 1.3b
        >>> stg_scale = 1.0  # Set 0.0 for CFG
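The variables above feed the pipeline call. A minimal usage sketch (an assumption, not part of this commit: the checkpoint ID is illustrative, and the stg_applied_layers_idx / stg_scale keyword names are inferred from the __call__ hunks further down):

    import torch
    from diffusers.utils import export_to_video

    # Hypothetical end-to-end usage of the WanSTGPipeline added in this commit.
    pipe = WanSTGPipeline.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # illustrative checkpoint ID
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        stg_applied_layers_idx=stg_applied_layers_idx,  # transformer blocks to perturb, e.g. [8]
        stg_scale=stg_scale,  # 0.0 falls back to plain classifier-free guidance
    ).frames[0]
    export_to_video(output, "output.mp4", fps=16)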
@@ -98,6 +97,7 @@ def prompt_clean(text):
    text = whitespace_clean(basic_clean(text))
    return text

+
def forward_with_stg(
    self,
    hidden_states: torch.Tensor,
@@ -107,35 +107,35 @@ def forward_with_stg(
) -> torch.Tensor:
    return hidden_states

+
def forward_without_stg(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         temb: torch.Tensor,
-         rotary_emb: torch.Tensor,
-     ) -> torch.Tensor:
-         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-             self.scale_shift_table + temb.float()
-         ).chunk(6, dim=1)
-
-         # 1. Self-attention
-         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-         attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
-         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
-
-         # 2. Cross-attention
-         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-         attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
-         hidden_states = hidden_states + attn_output
-
-         # 3. Feed-forward
-         norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
-             hidden_states
-         )
-         ff_output = self.ffn(norm_hidden_states)
-         hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+     self,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor,
+     temb: torch.Tensor,
+     rotary_emb: torch.Tensor,
+ ) -> torch.Tensor:
+     shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+         self.scale_shift_table + temb.float()
+     ).chunk(6, dim=1)
+
+     # 1. Self-attention
+     norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+     attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+     hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+     # 2. Cross-attention
+     norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+     attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+     hidden_states = hidden_states + attn_output
+
+     # 3. Feed-forward
+     norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
+     ff_output = self.ffn(norm_hidden_states)
+     hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+     return hidden_states

-         return hidden_states

class WanSTGPipeline(DiffusionPipeline, WanLoraLoaderMixin):
    r"""
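A note on the two module-level forwards above (annotation, not part of the commit): forward_with_stg returns its input unchanged, so rebinding it onto a transformer block turns that block into a no-op for the perturbed prediction, while forward_without_stg restores the standard Wan block computation (modulated self-attention, cross-attention, feed-forward). The __call__ hunks below swap them per denoising step, roughly as in this sketch, where transformer and stg_applied_layers_idx stand in for self.transformer and the corresponding call argument:

    import types

    # 1) Regular conditional / unconditional passes: every block runs the full computation.
    for block in transformer.blocks:
        block.forward = types.MethodType(forward_without_stg, block)

    # 2) Perturbed pass: the selected layers are skipped entirely.
    for idx, block in enumerate(transformer.blocks):
        if idx in stg_applied_layers_idx:
            block.forward = types.MethodType(forward_with_stg, block)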
@@ -386,7 +386,7 @@ def guidance_scale(self):
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1.0
-
+
    @property
    def do_spatio_temporal_guidance(self):
        return self._stg_scale > 0.0
@@ -577,9 +577,7 @@ def __call__(

                if self.do_spatio_temporal_guidance:
                    for idx, block in enumerate(self.transformer.blocks):
-                         block.forward = types.MethodType(
-                                 forward_without_stg, block
-                             )
+                         block.forward = types.MethodType(forward_without_stg, block)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
@@ -600,17 +598,19 @@ def __call__(
                    if self.do_spatio_temporal_guidance:
                        for idx, block in enumerate(self.transformer.blocks):
                            if idx in stg_applied_layers_idx:
-                                 block.forward = types.MethodType(
-                                         forward_with_stg, block
-                                     )
+                                 block.forward = types.MethodType(forward_with_stg, block)
                        noise_perturb = self.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep,
                            encoder_hidden_states=prompt_embeds,
                            attention_kwargs=attention_kwargs,
                            return_dict=False,
                        )[0]
-                         noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + self._stg_scale * (noise_pred - noise_perturb)
+                         noise_pred = (
+                             noise_uncond
+                             + guidance_scale * (noise_pred - noise_uncond)
+                             + self._stg_scale * (noise_pred - noise_perturb)
+                         )
                    else:
                        noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

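Reading the combined update above with the prompt-conditioned prediction given its own name (a restatement for clarity, not code from the commit), classifier-free guidance and spatio-temporal guidance enter as two additive correction terms on top of the unconditional prediction:

    # noise_pred_text: prediction conditioned on the prompt, all blocks intact
    # noise_uncond:    prediction from the negative / unconditional prompt
    # noise_perturb:   prompt-conditioned prediction with the selected STG layers skipped
    noise_pred = (
        noise_uncond
        + guidance_scale * (noise_pred_text - noise_uncond)  # CFG term
        + stg_scale * (noise_pred_text - noise_perturb)      # STG term
    )
    # With stg_scale == 0.0 this reduces to the plain CFG update in the else branch.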