from typing import Dict, List, Tuple, Union, Optional
from PIL import Image

- from diffsynth_engine.configs import BaseConfig, BaseStateDicts, LoraConfig
+ from diffsynth_engine.configs import (
+     BaseConfig,
+     BaseStateDicts,
+     LoraConfig,
+     AttnImpl,
+     SpargeAttentionParams,
+     VideoSparseAttentionParams,
+ )
+ from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
from diffsynth_engine.utils.gguf import load_gguf_checkpoint
@@ -33,6 +41,7 @@ def __init__(
        dtype=torch.float16,
    ):
        super().__init__()
+         self.config = None
        self.vae_tiled = vae_tiled
        self.vae_tile_size = vae_tile_size
        self.vae_tile_stride = vae_tile_stride
@@ -48,7 +57,7 @@ def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipelin
        raise NotImplementedError()

    @classmethod
-     def from_state_dict(cls, state_dicts: BaseStateDicts, pipeline_config: BaseConfig) -> "BasePipeline":
+     def from_state_dict(cls, state_dicts: BaseStateDicts, config: BaseConfig) -> "BasePipeline":
        raise NotImplementedError()

    def update_weights(self, state_dicts: BaseStateDicts) -> None:
@@ -260,6 +269,25 @@ def prepare_latents(
        )
        return init_latents, latents, sigmas, timesteps

+     def get_attn_kwargs(self, latents: torch.Tensor) -> Dict:
+         attn_kwargs = {"attn_impl": self.config.dit_attn_impl.value}
+         if isinstance(self.config.attn_params, SpargeAttentionParams):
+             assert self.config.dit_attn_impl == AttnImpl.SPARGE
+             attn_kwargs.update(
+                 {
+                     "smooth_k": self.config.attn_params.smooth_k,
+                     "simthreshd1": self.config.attn_params.simthreshd1,
+                     "cdfthreshd": self.config.attn_params.cdfthreshd,
+                     "pvthreshd": self.config.attn_params.pvthreshd,
+                 }
+             )
+         elif isinstance(self.config.attn_params, VideoSparseAttentionParams):
+             assert self.config.dit_attn_impl == AttnImpl.VSA
+             attn_kwargs.update(
+                 get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.config.attn_params.sparsity, device=self.device)
+             )
+         return attn_kwargs
+ 
    def eval(self):
        for model_name in self.model_names:
            model = getattr(self, model_name)
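For orientation, a concrete pipeline would be expected to call the new get_attn_kwargs on its denoising latents and forward the result into its DiT. The sketch below only illustrates that flow; the subclass name, the dit attribute, and the forward signature are assumptions, not part of this diff.

# Illustrative sketch only -- the subclass name, the `dit` attribute, and the
# forward signature are assumptions; this commit only adds get_attn_kwargs.
class MyVideoPipeline(BasePipeline):
    def denoise_step(self, latents: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # get_attn_kwargs reports the configured attention implementation and,
        # depending on attn_params, either the SpargeAttn thresholds or the VSA
        # tile layout derived from latents.shape[2:].
        attn_kwargs = self.get_attn_kwargs(latents)
        # Thread the kwargs down to the DiT so its attention layers can pick
        # the matching implementation and parameters.
        return self.dit(latents, timestep=timestep, context=context, **attn_kwargs)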