Skip to content

Commit 2939ba0

Browse files
Apply style fixes
1 parent d222503 commit 2939ba0

File tree

1 file changed

+40
-40
lines changed

1 file changed

+40
-40
lines changed

examples/community/pipeline_stg_wan.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
# limitations under the License.
1414

1515
import html
16-
1716
import types
18-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17+
from typing import Any, Callable, Dict, List, Optional, Union
1918

2019
import ftfy
2120
import regex as re
@@ -25,12 +24,12 @@
2524
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
2625
from diffusers.loaders import WanLoraLoaderMixin
2726
from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
27+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28+
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
2829
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
2930
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
3031
from diffusers.utils.torch_utils import randn_tensor
3132
from diffusers.video_processor import VideoProcessor
32-
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
33-
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
3433

3534

3635
if is_torch_xla_available():
@@ -62,7 +61,7 @@
6261
6362
>>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
6463
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
65-
64+
6665
>>> # Configure STG mode options
6766
>>> stg_applied_layers_idx = [8] # Layer indices from 0 to 39 for 14b or 0 to 29 for 1.3b
6867
>>> stg_scale = 1.0 # Set 0.0 for CFG
@@ -98,6 +97,7 @@ def prompt_clean(text):
9897
text = whitespace_clean(basic_clean(text))
9998
return text
10099

100+
101101
def forward_with_stg(
102102
self,
103103
hidden_states: torch.Tensor,
@@ -107,35 +107,35 @@ def forward_with_stg(
107107
) -> torch.Tensor:
108108
return hidden_states
109109

110+
110111
def forward_without_stg(
111-
self,
112-
hidden_states: torch.Tensor,
113-
encoder_hidden_states: torch.Tensor,
114-
temb: torch.Tensor,
115-
rotary_emb: torch.Tensor,
116-
) -> torch.Tensor:
117-
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
118-
self.scale_shift_table + temb.float()
119-
).chunk(6, dim=1)
120-
121-
# 1. Self-attention
122-
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
123-
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
124-
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
125-
126-
# 2. Cross-attention
127-
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
128-
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
129-
hidden_states = hidden_states + attn_output
130-
131-
# 3. Feed-forward
132-
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
133-
hidden_states
134-
)
135-
ff_output = self.ffn(norm_hidden_states)
136-
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
112+
self,
113+
hidden_states: torch.Tensor,
114+
encoder_hidden_states: torch.Tensor,
115+
temb: torch.Tensor,
116+
rotary_emb: torch.Tensor,
117+
) -> torch.Tensor:
118+
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
119+
self.scale_shift_table + temb.float()
120+
).chunk(6, dim=1)
121+
122+
# 1. Self-attention
123+
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
124+
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
125+
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
126+
127+
# 2. Cross-attention
128+
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
129+
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
130+
hidden_states = hidden_states + attn_output
131+
132+
# 3. Feed-forward
133+
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
134+
ff_output = self.ffn(norm_hidden_states)
135+
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
136+
137+
return hidden_states
137138

138-
return hidden_states
139139

140140
class WanSTGPipeline(DiffusionPipeline, WanLoraLoaderMixin):
141141
r"""
@@ -386,7 +386,7 @@ def guidance_scale(self):
386386
@property
387387
def do_classifier_free_guidance(self):
388388
return self._guidance_scale > 1.0
389-
389+
390390
@property
391391
def do_spatio_temporal_guidance(self):
392392
return self._stg_scale > 0.0
@@ -577,9 +577,7 @@ def __call__(
577577

578578
if self.do_spatio_temporal_guidance:
579579
for idx, block in enumerate(self.transformer.blocks):
580-
block.forward = types.MethodType(
581-
forward_without_stg, block
582-
)
580+
block.forward = types.MethodType(forward_without_stg, block)
583581

584582
noise_pred = self.transformer(
585583
hidden_states=latent_model_input,
@@ -600,17 +598,19 @@ def __call__(
600598
if self.do_spatio_temporal_guidance:
601599
for idx, block in enumerate(self.transformer.blocks):
602600
if idx in stg_applied_layers_idx:
603-
block.forward = types.MethodType(
604-
forward_with_stg, block
605-
)
601+
block.forward = types.MethodType(forward_with_stg, block)
606602
noise_perturb = self.transformer(
607603
hidden_states=latent_model_input,
608604
timestep=timestep,
609605
encoder_hidden_states=prompt_embeds,
610606
attention_kwargs=attention_kwargs,
611607
return_dict=False,
612608
)[0]
613-
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + self._stg_scale * (noise_pred - noise_perturb)
609+
noise_pred = (
610+
noise_uncond
611+
+ guidance_scale * (noise_pred - noise_uncond)
612+
+ self._stg_scale * (noise_pred - noise_perturb)
613+
)
614614
else:
615615
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
616616

0 commit comments

Comments
 (0)