Skip to content

Commit f7119c8

Browse files
authored
auto enable vsa (#203)
1 parent 6b2825e commit f7119c8

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

diffsynth_engine/pipelines/wan_s2v.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,8 @@ def _from_state_dict(
656656
)
657657

658658
with LoRAContext():
659+
cls._auto_enable_vsa(state_dicts.model, config)
660+
659661
dit = WanS2VDiT.from_state_dict(
660662
state_dicts.model,
661663
config=model_config,

diffsynth_engine/pipelines/wan_video.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from tqdm import tqdm
55
from PIL import Image
66

7-
from diffsynth_engine.configs import WanPipelineConfig, WanStateDicts
7+
from diffsynth_engine.configs import WanPipelineConfig, WanStateDicts, AttnImpl, VideoSparseAttentionParams
88
from diffsynth_engine.algorithm.noise_scheduler.flow_match import RecifitedFlowScheduler
99
from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
1010
from diffsynth_engine.models.wan.wan_dit import WanDiT
@@ -584,6 +584,8 @@ def _from_state_dict(cls, state_dicts: WanStateDicts, config: WanPipelineConfig)
584584
dit_state_dict = state_dicts.model
585585

586586
with LoRAContext():
587+
cls._auto_enable_vsa(dit_state_dict, config)
588+
587589
dit = WanDiT.from_state_dict(
588590
dit_state_dict,
589591
config=dit_config,
@@ -668,6 +670,16 @@ def _get_vae_type(vae_state_dict: Dict[str, torch.Tensor]) -> str:
668670
vae_type = "wan2.2-vae"
669671
return vae_type
670672

673+
@staticmethod
674+
def _auto_enable_vsa(state_dict: Dict[str, torch.Tensor], config: WanPipelineConfig):
675+
def has_any_key(*xs):
676+
return any(x in state_dict for x in xs)
677+
678+
if has_any_key("blocks.0.to_gate_compress.weight", "blocks.0.self_attn.gate_compress.weight"):
679+
config.dit_attn_impl = AttnImpl.VSA
680+
if config.attn_params is None:
681+
config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
682+
671683
def compile(self):
672684
self.dit.compile_repeated_blocks()
673685
if self.dit2 is not None:

0 commit comments

Comments
 (0)