
Commit 1f969ae

Fix/import vsa (#200)
* fix cycle import vsa
* fix cycle import parallel
1 parent 0590444 commit 1f969ae

10 files changed: +71 -45 lines changed


diffsynth_engine/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -12,11 +12,13 @@
     WanStateDicts,
     QwenImageStateDicts,
     AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+    LoraConfig,
     ControlNetParams,
     ControlType,
     QwenImageControlNetParams,
     QwenImageControlType,
-    LoraConfig,
 )
 from .pipelines import (
     SDImagePipeline,
@@ -59,6 +61,9 @@
     "WanStateDicts",
     "QwenImageStateDicts",
     "AttnImpl",
+    "SpargeAttentionParams",
+    "VideoSparseAttentionParams",
+    "LoraConfig",
     "ControlNetParams",
     "ControlType",
     "QwenImageControlNetParams",
@@ -79,7 +84,6 @@
     "FluxIPAdapterRefTool",
     "FluxReplaceByControlTool",
     "FluxReduxRefTool",
-    "LoraConfig",
     "fetch_model",
     "fetch_modelscope_model",
     "register_fetch_modelscope_model",

diffsynth_engine/configs/__init__.py

Lines changed: 10 additions & 6 deletions
@@ -17,14 +17,16 @@
     WanStateDicts,
     WanS2VStateDicts,
     QwenImageStateDicts,
-    LoraConfig,
     AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+    LoraConfig,
 )
 from .controlnet import (
     ControlType,
     ControlNetParams,
-    QwenImageControlNetParams,
     QwenImageControlType,
+    QwenImageControlNetParams,
 )
 
 __all__ = [
@@ -46,10 +48,12 @@
     "WanStateDicts",
     "WanS2VStateDicts",
     "QwenImageStateDicts",
-    "QwenImageControlType",
-    "QwenImageControlNetParams",
+    "AttnImpl",
+    "SpargeAttentionParams",
+    "VideoSparseAttentionParams",
+    "LoraConfig",
     "ControlType",
     "ControlNetParams",
-    "LoraConfig",
-    "AttnImpl",
+    "QwenImageControlType",
+    "QwenImageControlNetParams",
 ]

diffsynth_engine/configs/pipeline.py

Lines changed: 0 additions & 18 deletions
@@ -5,7 +5,6 @@
 from typing import List, Dict, Tuple, Optional
 
 from diffsynth_engine.configs.controlnet import ControlType
-from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 
 
 @dataclass
@@ -52,23 +51,6 @@ class AttentionConfig:
     dit_attn_impl: AttnImpl = AttnImpl.AUTO
     attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None
 
-    def get_attn_kwargs(self, latents: torch.Tensor, device: str) -> Dict:
-        attn_kwargs = {"attn_impl": self.dit_attn_impl.value}
-        if isinstance(self.attn_params, SpargeAttentionParams):
-            assert self.dit_attn_impl == AttnImpl.SPARGE
-            attn_kwargs.update(
-                {
-                    "smooth_k": self.attn_params.smooth_k,
-                    "simthreshd1": self.attn_params.simthreshd1,
-                    "cdfthreshd": self.attn_params.cdfthreshd,
-                    "pvthreshd": self.attn_params.pvthreshd,
-                }
-            )
-        elif isinstance(self.attn_params, VideoSparseAttentionParams):
-            assert self.dit_attn_impl == AttnImpl.VSA
-            attn_kwargs.update(get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.attn_params.sparsity, device=device))
-        return attn_kwargs
-
 
 @dataclass
 class OptimizationConfig:

diffsynth_engine/models/basic/video_sparse_attention.py

Lines changed: 4 additions & 1 deletion
@@ -2,9 +2,12 @@
 import math
 import functools
 
-from vsa import video_sparse_attn as vsa_core
+from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
 from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
 
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from vsa import video_sparse_attn as vsa_core
+
 VSA_TILE_SIZE = (4, 4, 4)
 
 
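
The guard above only helps if VIDEO_SPARSE_ATTN_AVAILABLE can be evaluated without importing vsa itself. diffsynth_engine/utils/flag.py is not part of this diff, so the snippet below is only a sketch of how such a flag is commonly defined, not the repository's actual implementation:

    import importlib.util

    # Probe for the optional `vsa` package without importing it, so the guarded
    # `from vsa import video_sparse_attn` above is skipped when the dependency
    # is missing instead of raising ImportError at module import time.
    VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None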

diffsynth_engine/pipelines/base.py

Lines changed: 30 additions & 2 deletions
@@ -5,7 +5,15 @@
 from typing import Dict, List, Tuple, Union, Optional
 from PIL import Image
 
-from diffsynth_engine.configs import BaseConfig, BaseStateDicts, LoraConfig
+from diffsynth_engine.configs import (
+    BaseConfig,
+    BaseStateDicts,
+    LoraConfig,
+    AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+)
+from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
 from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
 from diffsynth_engine.utils.gguf import load_gguf_checkpoint
@@ -33,6 +41,7 @@ def __init__(
         dtype=torch.float16,
     ):
         super().__init__()
+        self.config = None
         self.vae_tiled = vae_tiled
         self.vae_tile_size = vae_tile_size
         self.vae_tile_stride = vae_tile_stride
@@ -48,7 +57,7 @@ def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
         raise NotImplementedError()
 
     @classmethod
-    def from_state_dict(cls, state_dicts: BaseStateDicts, pipeline_config: BaseConfig) -> "BasePipeline":
+    def from_state_dict(cls, state_dicts: BaseStateDicts, config: BaseConfig) -> "BasePipeline":
         raise NotImplementedError()
 
     def update_weights(self, state_dicts: BaseStateDicts) -> None:
@@ -260,6 +269,25 @@ def prepare_latents(
         )
         return init_latents, latents, sigmas, timesteps
 
+    def get_attn_kwargs(self, latents: torch.Tensor) -> Dict:
+        attn_kwargs = {"attn_impl": self.config.dit_attn_impl.value}
+        if isinstance(self.config.attn_params, SpargeAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.SPARGE
+            attn_kwargs.update(
+                {
+                    "smooth_k": self.config.attn_params.smooth_k,
+                    "simthreshd1": self.config.attn_params.simthreshd1,
+                    "cdfthreshd": self.config.attn_params.cdfthreshd,
+                    "pvthreshd": self.config.attn_params.pvthreshd,
+                }
+            )
+        elif isinstance(self.config.attn_params, VideoSparseAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.VSA
+            attn_kwargs.update(
+                get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.config.attn_params.sparsity, device=self.device)
+            )
+        return attn_kwargs
+
     def eval(self):
         for model_name in self.model_names:
             model = getattr(self, model_name)
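
Relocating get_attn_kwargs from AttentionConfig onto BasePipeline lets configs/pipeline.py drop its import of get_vsa_kwargs (see the hunk in that file above), which was one leg of the import cycle named in the commit title; the pipeline already knows its device, so the explicit device argument also disappears from call sites. A minimal usage sketch under the new API, assuming a pipeline configured for video sparse attention (the helper name and sparsity value are made up for illustration):

    from diffsynth_engine import AttnImpl, VideoSparseAttentionParams

    def build_vsa_attn_kwargs(pipeline, latents):
        # hypothetical helper: switch an existing pipeline's DiT attention to VSA
        pipeline.config.dit_attn_impl = AttnImpl.VSA
        pipeline.config.attn_params = VideoSparseAttentionParams(sparsity=0.9)  # assumed constructor field
        # before this commit: pipeline.config.get_attn_kwargs(latents, pipeline.device)
        # after this commit:  the pipeline owns the method and supplies its own device
        return pipeline.get_attn_kwargs(latents)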

diffsynth_engine/pipelines/flux_image.py

Lines changed: 2 additions & 2 deletions
@@ -751,7 +751,7 @@ def predict_noise(
         latents = latents.to(self.dtype)
         self.load_models_to_device(["dit"])
 
-        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,
@@ -886,7 +886,7 @@ def predict_multicontrolnet(
             empty_cache()
             param.model.to(self.device)
 
-        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)
         double_block_output, single_block_output = param.model(
             hidden_states=latents,
             control_condition=control_condition,

diffsynth_engine/pipelines/qwen_image.py

Lines changed: 4 additions & 2 deletions
@@ -208,7 +208,9 @@ def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig) ->
         )
         if config.load_encoder:
             logger.info(f"loading state dict from {config.encoder_path} ...")
-            encoder_state_dict = cls.load_model_checkpoint(config.encoder_path, device="cpu", dtype=config.encoder_dtype)
+            encoder_state_dict = cls.load_model_checkpoint(
+                config.encoder_path, device="cpu", dtype=config.encoder_dtype
+            )
 
         state_dicts = QwenImageStateDicts(
             model=model_state_dict,
@@ -547,7 +549,7 @@ def predict_noise(
         entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         self.load_models_to_device(["dit"])
-        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             image=latents,
             edit=image_latents,

diffsynth_engine/pipelines/wan_s2v.py

Lines changed: 1 addition & 1 deletion
@@ -394,7 +394,7 @@ def predict_noise(
         void_audio_input: torch.Tensor | None = None,
     ):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
-        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)
 
         noise_pred = model(
             x=latents,

diffsynth_engine/pipelines/wan_video.py

Lines changed: 8 additions & 4 deletions
@@ -144,7 +144,7 @@ def load_loras(
         lora_list: List[Tuple[str, float]],
         fused: bool = True,
         save_original_weight: bool = False,
-        lora_converter: Optional[WanLoRAConverter] = None
+        lora_converter: Optional[WanLoRAConverter] = None,
     ):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
             "load LoRA is not allowed when tensor parallel is enabled; "
@@ -156,11 +156,15 @@ def load_loras(
         )
         super().load_loras(lora_list, fused, save_original_weight, lora_converter)
 
-    def load_loras_low_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+    def load_loras_low_noise(
+        self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False
+    ):
         assert self.dit2 is not None, "low noise LoRA can only be applied to Wan2.2"
         self.load_loras(lora_list, fused, save_original_weight, self.low_noise_lora_converter)
 
-    def load_loras_high_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+    def load_loras_high_noise(
+        self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False
+    ):
         assert self.dit2 is not None, "high noise LoRA can only be applied to Wan2.2"
         self.load_loras(lora_list, fused, save_original_weight)
 
@@ -323,7 +327,7 @@ def predict_noise_with_cfg(
 
     def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, context):
         latents = latents.to(dtype=self.config.model_dtype, device=self.device)
-        attn_kwargs = self.config.get_attn_kwargs(latents, self.device)
+        attn_kwargs = self.get_attn_kwargs(latents)
 
         noise_pred = model(
             x=latents,

diffsynth_engine/utils/parallel.py

Lines changed: 6 additions & 7 deletions
@@ -19,8 +19,6 @@
 from queue import Empty
 
 import diffsynth_engine.models.basic.attention as attention_ops
-from diffsynth_engine.models import PreTrainedModel
-from diffsynth_engine.pipelines import BasePipeline
 from diffsynth_engine.utils.platform import empty_cache
 from diffsynth_engine.utils import logging
 
@@ -300,14 +298,15 @@ def _worker_loop(
             world_size=world_size,
         )
 
-    def wrap_for_parallel(module: Union[PreTrainedModel, BasePipeline]):
-        if isinstance(module, BasePipeline):
-            for model_name in module.model_names:
-                if isinstance(submodule := getattr(module, model_name), PreTrainedModel):
+    def wrap_for_parallel(module):
+        if hasattr(module, "model_names"):
+            for model_name in getattr(module, "model_names"):
+                submodule = getattr(module, model_name)
+                if getattr(submodule, "_supports_parallelization", False):
                     setattr(module, model_name, wrap_for_parallel(submodule))
             return module
 
-        if not module._supports_parallelization:
+        if not getattr(module, "_supports_parallelization", False):
             return module
 
         if tp_degree > 1:
0 commit comments
