Commit cd4e102

new PipelineConfig for initialization (#123)

* fix sparge_attn in long_context_attention
* new PipelineConfig for initialization
1 parent dd2029e commit cd4e102


44 files changed (+751, -620 lines)

diffsynth_engine/__init__.py

Lines changed: 10 additions & 8 deletions
@@ -1,12 +1,14 @@
+from .configs import (
+    SDPipelineConfig,
+    SDXLPipelineConfig,
+    FluxPipelineConfig,
+    WanPipelineConfig,
+)
 from .pipelines import (
     FluxImagePipeline,
     SDXLImagePipeline,
     SDImagePipeline,
     WanVideoPipeline,
-    FluxModelConfig,
-    SDXLModelConfig,
-    SDModelConfig,
-    WanModelConfig,
     ControlNetParams,
 )
 from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
@@ -23,6 +25,10 @@
 )
 
 __all__ = [
+    "SDPipelineConfig",
+    "SDXLPipelineConfig",
+    "FluxPipelineConfig",
+    "WanPipelineConfig",
     "FluxImagePipeline",
     "FluxControlNet",
     "FluxIPAdapter",
@@ -32,10 +38,6 @@
     "SDXLImagePipeline",
     "SDImagePipeline",
     "WanVideoPipeline",
-    "FluxModelConfig",
-    "SDXLModelConfig",
-    "SDModelConfig",
-    "WanModelConfig",
     "FluxInpaintingTool",
     "FluxOutpaintingTool",
     "FluxIPAdapterRefTool",

diffsynth_engine/configs/__init__.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from .pipeline import (
+    BaseConfig,
+    AttentionConfig,
+    OptimizationConfig,
+    ParallelConfig,
+    SDPipelineConfig,
+    SDXLPipelineConfig,
+    FluxPipelineConfig,
+    WanPipelineConfig,
+)
+from .controlnet import ControlType
+
+__all__ = [
+    "BaseConfig",
+    "AttentionConfig",
+    "OptimizationConfig",
+    "ParallelConfig",
+    "SDPipelineConfig",
+    "SDXLPipelineConfig",
+    "FluxPipelineConfig",
+    "WanPipelineConfig",
+    "ControlType",
+]

diffsynth_engine/configs/controlnet.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from enum import Enum
+
+
+# FLUX ControlType
+class ControlType(Enum):
+    normal = "normal"
+    bfl_control = "bfl_control"
+    bfl_fill = "bfl_fill"
+    bfl_kontext = "bfl_kontext"
+
+    def get_in_channel(self):
+        if self in [ControlType.normal, ControlType.bfl_kontext]:
+            return 64
+        elif self == ControlType.bfl_control:
+            return 128
+        elif self == ControlType.bfl_fill:
+            return 384

diffsynth_engine/configs/pipeline.py

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+import os
+import torch
+from dataclasses import dataclass, field
+from typing import List, Tuple, Optional
+
+from diffsynth_engine.configs.controlnet import ControlType
+
+
+@dataclass
+class BaseConfig:
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    model_dtype: torch.dtype
+    batch_cfg: bool = False
+    vae_tiled: bool = False
+    vae_tile_size: int | Tuple[int, int] = 256
+    vae_tile_stride: int | Tuple[int, int] = 256
+    device: str = "cuda"
+    offload_mode: Optional[str] = None
+
+
+@dataclass
+class AttentionConfig:
+    dit_attn_impl: str = "auto"
+    # Sparge Attention
+    sparge_smooth_k: bool = True
+    sparge_cdfthreshd: float = 0.6
+    sparge_simthreshd1: float = 0.98
+    sparge_pvthreshd: float = 50.0
+
+
+@dataclass
+class OptimizationConfig:
+    use_fp8_linear: bool = False
+    use_fbcache: bool = False
+    fbcache_relative_l1_threshold: float = 0.05
+
+
+@dataclass
+class ParallelConfig:
+    parallelism: int = 1
+    use_cfg_parallel: bool = False
+    cfg_degree: Optional[int] = None
+    sp_ulysses_degree: Optional[int] = None
+    sp_ring_degree: Optional[int] = None
+    tp_degree: Optional[int] = None
+    use_fsdp: bool = False
+
+
+@dataclass
+class SDPipelineConfig(BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    clip_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    model_dtype: torch.dtype = torch.float16
+    clip_dtype: torch.dtype = torch.float16
+    vae_dtype: torch.dtype = torch.float32
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        device: str = "cuda",
+        offload_mode: Optional[str] = None,
+    ) -> "SDPipelineConfig":
+        return cls(
+            model_path=model_path,
+            device=device,
+            offload_mode=offload_mode,
+        )
+
+
+@dataclass
+class SDXLPipelineConfig(BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    clip_l_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    clip_g_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    model_dtype: torch.dtype = torch.float16
+    clip_l_dtype: torch.dtype = torch.float16
+    clip_g_dtype: torch.dtype = torch.float16
+    vae_dtype: torch.dtype = torch.float32
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        device: str = "cuda",
+        offload_mode: Optional[str] = None,
+    ) -> "SDXLPipelineConfig":
+        return cls(
+            model_path=model_path,
+            device=device,
+            offload_mode=offload_mode,
+        )
+
+
+@dataclass
+class FluxPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    clip_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    t5_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    model_dtype: torch.dtype = torch.bfloat16
+    clip_dtype: torch.dtype = torch.bfloat16
+    t5_dtype: torch.dtype = torch.bfloat16
+    vae_dtype: torch.dtype = torch.bfloat16
+
+    load_text_encoder: bool = True
+    control_type: ControlType = ControlType.normal
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        device: str = "cuda",
+        parallelism: int = 1,
+        offload_mode: Optional[str] = None,
+    ) -> "FluxPipelineConfig":
+        return cls(
+            model_path=model_path,
+            device=device,
+            parallelism=parallelism,
+            use_fsdp=True,
+            offload_mode=offload_mode,
+        )
+
+    def __post_init__(self):
+        init_parallel_config(self)
+
+
+@dataclass
+class WanPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    t5_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    model_dtype: torch.dtype = torch.bfloat16
+    t5_dtype: torch.dtype = torch.bfloat16
+    vae_dtype: torch.dtype = torch.bfloat16
+    image_encoder_dtype: torch.dtype = torch.bfloat16
+
+    shift: Optional[float] = field(default=None, init=False)  # RecifitedFlowScheduler shift factor, set by model type
+
+    # override BaseConfig
+    vae_tiled: bool = True
+    vae_tile_size: Tuple[int, int] = (34, 34)
+    vae_tile_stride: Tuple[int, int] = (18, 16)
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        device: str = "cuda",
+        parallelism: int = 1,
+        offload_mode: Optional[str] = None,
+    ) -> "WanPipelineConfig":
+        return cls(
+            model_path=model_path,
+            image_encoder_path=image_encoder_path,
+            device=device,
+            parallelism=parallelism,
+            use_cfg_parallel=True,
+            use_fsdp=True,
+            offload_mode=offload_mode,
+        )
+
+    def __post_init__(self):
+        init_parallel_config(self)
+
+
+def init_parallel_config(config: FluxPipelineConfig | WanPipelineConfig):
+    assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
+    config.batch_cfg = True if config.parallelism > 1 and config.use_cfg_parallel else config.batch_cfg
+
+    if config.use_cfg_parallel is True and config.cfg_degree is not None:
+        raise ValueError("use_cfg_parallel and cfg_degree should not be specified together")
+    config.cfg_degree = (2 if config.use_cfg_parallel else 1) if config.cfg_degree is None else config.cfg_degree
+
+    if config.tp_degree is not None:
+        assert config.sp_ulysses_degree is None and config.sp_ring_degree is None, (
+            "not allowed to enable sequence parallel and tensor parallel together; "
+            "either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
+        )
+        assert config.use_fsdp is False, (
+            "not allowed to enable fully sharded data parallel and tensor parallel together; "
+            "either set use_fsdp=False or set tp_degree=None during pipeline initialization"
+        )
+        assert config.parallelism == config.cfg_degree * config.tp_degree, (
+            f"parallelism ({config.parallelism}) must be equal to cfg_degree ({config.cfg_degree}) * tp_degree ({config.tp_degree})"
+        )
+        config.sp_ulysses_degree = 1
+        config.sp_ring_degree = 1
+    elif config.sp_ulysses_degree is None and config.sp_ring_degree is None:
+        # use ulysses if not specified
+        config.sp_ulysses_degree = config.parallelism // config.cfg_degree
+        config.sp_ring_degree = 1
+        config.tp_degree = 1
+    elif config.sp_ulysses_degree is not None and config.sp_ring_degree is not None:
+        assert config.parallelism == config.cfg_degree * config.sp_ulysses_degree * config.sp_ring_degree, (
+            f"parallelism ({config.parallelism}) must be equal to cfg_degree ({config.cfg_degree}) * "
+            f"sp_ulysses_degree ({config.sp_ulysses_degree}) * sp_ring_degree ({config.sp_ring_degree})"
+        )
+        config.tp_degree = 1
+    else:
+        raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")

diffsynth_engine/models/basic/attention.py

Lines changed: 43 additions & 4 deletions
@@ -61,12 +61,33 @@ def sage_attn(q, k, v, attn_mask=None, scale=None):
 
 if SPARGE_ATTN_AVAILABLE:
     from spas_sage_attn import spas_sage2_attn_meansim_cuda
+    from spas_sage_attn.autotune import SparseAttentionMeansim
 
-    def sparge_attn(q, k, v, attn_mask=None, scale=None):
+    def sparge_attn(
+        q,
+        k,
+        v,
+        attn_mask=None,
+        scale=None,
+        smooth_k=True,
+        simthreshd1=0.6,
+        cdfthreshd=0.98,
+        pvthreshd=50,
+    ):
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
-        out = spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=attn_mask, scale=scale)
+        out = spas_sage2_attn_meansim_cuda(
+            q,
+            k,
+            v,
+            attn_mask=attn_mask,
+            scale=scale,
+            smooth_k=smooth_k,
+            simthreshd1=simthreshd1,
+            cdfthreshd=cdfthreshd,
+            pvthreshd=pvthreshd,
+        )
         return out.transpose(1, 2)
 
 
@@ -91,6 +112,7 @@ def attention(
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
+    **kwargs,
 ):
     """
     q: [B, Lq, Nq, C1]
@@ -133,7 +155,17 @@
     elif attn_impl == "sage_attn":
         return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
     elif attn_impl == "sparge_attn":
-        return sparge_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        return sparge_attn(
+            q,
+            k,
+            v,
+            attn_mask=attn_mask,
+            scale=scale,
+            smooth_k=kwargs.get("sparge_smooth_k", True),
+            simthreshd1=kwargs.get("sparge_simthreshd1", 0.6),
+            cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
+            pvthreshd=kwargs.get("sparge_pvthreshd", 50),
+        )
     else:
         raise ValueError(f"Invalid attention implementation: {attn_impl}")
 
@@ -189,6 +221,7 @@ def long_context_attention(
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
+    **kwargs,
 ):
     """
     q: [B, Lq, Nq, C1]
@@ -226,7 +259,13 @@
     elif attn_impl == "sage_attn":
        attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
     elif attn_impl == "sparge_attn":
-        attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE)
+        attn_processor = SparseAttentionMeansim()
+        # default args from spas_sage2_attn_meansim_cuda
+        attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
+        attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
+        attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
+        attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
+        attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)
     else:
         raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
     return attn_func(q, k, v, softmax_scale=scale)
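
For completeness, a hypothetical caller-side sketch of how the new **kwargs path is meant to be fed; the sparge_* keys match the kwargs.get(...) lookups above and the values mirror their fallback defaults, but the tensor shapes are illustrative and running it requires a CUDA device with spas_sage_attn installed:

import torch
from diffsynth_engine.models.basic.attention import attention

# Illustrative shapes following the docstring convention q: [B, Lq, Nq, C1].
q = torch.randn(1, 1024, 24, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 1024, 24, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 1024, 24, 128, dtype=torch.bfloat16, device="cuda")

# The sparge_* kwargs are forwarded to sparge_attn via the new **kwargs parameter.
out = attention(
    q, k, v,
    attn_impl="sparge_attn",
    sparge_smooth_k=True,
    sparge_simthreshd1=0.6,
    sparge_cdfthreshd=0.98,
    sparge_pvthreshd=50,
)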
