
Commit 3888f1d

1. Refactor Attention Component
2. Support downloading only specific files from ModelScope (#27)

* refactor attention
* bug fix
* bug fix
* bug fix
* fix test and attention bug
* support only download part of models in one repo
* fix attn bug
* ruff format
* format import

Co-authored-by: zhuguoxuan.zgx <[email protected]>

1 parent e2115a0 · commit 3888f1d

33 files changed: +375 -551 lines changed

diffsynth_engine/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -10,6 +10,7 @@
 )
 from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
 from .utils.video import load_video, save_video
+
 __all__ = [
     "FluxImagePipeline",
     "SDXLImagePipeline",
@@ -22,4 +23,6 @@
     "fetch_model",
     "fetch_modelscope_model",
     "fetch_civitai_model",
+    "load_video",
+    "save_video",
 ]
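
The change above is additive: load_video and save_video were already imported from .utils.video, and this commit simply lists them in __all__ so they are part of the advertised top-level API. A minimal import sketch (their signatures are not shown in this diff):

    # After this commit, the video helpers are part of the documented package surface.
    from diffsynth_engine import load_video, save_video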

diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py

Lines changed: 16 additions & 14 deletions

@@ -5,18 +5,19 @@


 class RecifitedFlowScheduler(BaseScheduler):
-    def __init__(self,
-                 shift=1.0,
-                 sigma_min=0.001,
+    def __init__(
+        self,
+        shift=1.0,
+        sigma_min=0.001,
         sigma_max=1.0,
-                 num_train_timesteps=1000,
+        num_train_timesteps=1000,
         use_dynamic_shifting=False,
     ):
         self.shift = shift
         self.sigma_min = sigma_min
         self.sigma_max = sigma_max
-        self.num_train_timesteps = num_train_timesteps
-        self.use_dynamic_shifting = use_dynamic_shifting
+        self.num_train_timesteps = num_train_timesteps
+        self.use_dynamic_shifting = use_dynamic_shifting

     def _sigma_to_t(self, sigma):
         return sigma * self.num_train_timesteps
@@ -30,19 +31,20 @@ def _time_shift(self, mu: float, sigma: float, t: torch.Tensor):
     def _shift_sigma(self, sigma: torch.Tensor, shift: float):
         return shift * sigma / (1 + (shift - 1) * sigma)

-    def schedule(self,
-                 num_inference_steps: int,
-                 mu: float | None = None,
-                 sigma_min: float | None = None,
-                 sigma_max: float | None = None
+    def schedule(
+        self,
+        num_inference_steps: int,
+        mu: float | None = None,
+        sigma_min: float | None = None,
+        sigma_max: float | None = None,
     ):
         sigma_min = self.sigma_min if sigma_min is None else sigma_min
-        sigma_max = self.sigma_max if sigma_max is None else sigma_max
+        sigma_max = self.sigma_max if sigma_max is None else sigma_max
         sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
         if self.use_dynamic_shifting:
-            sigmas = self._time_shift(mu, 1.0, sigmas) # FLUX
+            sigmas = self._time_shift(mu, 1.0, sigmas)  # FLUX
         else:
             sigmas = self._shift_sigma(sigmas, self.shift)
         timesteps = sigmas * self.num_train_timesteps
         sigmas = append_zero(sigmas)
-        return sigmas, timesteps
+        return sigmas, timesteps
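
The edits above are formatting only (ruff), so the scheduler's behavior is unchanged. A minimal usage sketch based on the signatures visible in this hunk; the import path is inferred from the file path, and shift=3.0 is an arbitrary illustrative value:

    from diffsynth_engine.algorithm.noise_scheduler.flow_match.recifited_flow import RecifitedFlowScheduler

    scheduler = RecifitedFlowScheduler(shift=3.0, num_train_timesteps=1000)
    sigmas, timesteps = scheduler.schedule(num_inference_steps=20)
    # sigmas runs from sigma_max down to sigma_min (shifted), with a zero appended at the end,
    # while timesteps = sigmas * num_train_timesteps is computed before the zero is added,
    # so len(sigmas) == len(timesteps) + 1.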

diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py

Lines changed: 0 additions & 3 deletions

@@ -1,7 +1,4 @@
 import torch
-from .linear import ScaledLinearScheduler
-from ..base_scheduler import append_zero
-import numpy as np

 from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
 from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero

diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py

Lines changed: 0 additions & 3 deletions

@@ -1,7 +1,4 @@
 import torch
-from .linear import ScaledLinearScheduler
-from ..base_scheduler import append_zero
-import numpy as np

 from diffsynth_engine.algorithm.noise_scheduler.stable_diffusion.linear import ScaledLinearScheduler
 from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import append_zero

diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@


 class FlowMatchEulerSampler:
-    def initialize(self, init_latents, timesteps, sigmas, mask=None):
+    def initialize(self, init_latents, timesteps, sigmas, mask=None):
         self.init_latents = init_latents
         self.timesteps = timesteps
         self.sigmas = sigmas

diffsynth_engine/kernels/__init__.py

Whitespace-only changes.

[attention module; file name missing from this capture]

Lines changed: 117 additions & 100 deletions

@@ -1,12 +1,115 @@
 import torch
 import torch.nn as nn
 from einops import rearrange
-
+from typing import Optional
 from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.flag import (
+    FLASH_ATTN_3_AVAILABLE,
+    FLASH_ATTN_2_AVAILABLE,
+    XFORMERS_AVAILABLE,
+    SDPA_AVAILABLE,
+    SAGE_ATTN_AVAILABLE,
+    SPARGE_ATTN_AVAILABLE,
+)
+
+if FLASH_ATTN_3_AVAILABLE:
+    from flash_attn_interface import flash_attn_func as flash_attn3
+if FLASH_ATTN_2_AVAILABLE:
+    from flash_attn import flash_attn_func as flash_attn2
+if XFORMERS_AVAILABLE:
+    import xformers.ops.memory_efficient_attention as xformers_attn
+if SDPA_AVAILABLE:
+
+    def sdpa_attn(q, k, v, attn_mask=None, scale=None):
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
+        return out.transpose(1, 2)
+
+
+if SAGE_ATTN_AVAILABLE:
+    from sageattention import sageattn
+
+    def sage_attn(q, k, v, attn_mask=None, scale=None):
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        out = sageattn(q, k, v, attn_mask=attn_mask, sm_scale=scale)
+        return out.transpose(1, 2)
+
+
+if SPARGE_ATTN_AVAILABLE:
+    from spas_sage_attn import spas_sage2_attn_meansim_cuda
+
+    def sparge_attn(self, q, k, v, attn_mask=None, scale=None):
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        out = spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=attn_mask, scale=scale)
+        return out.transpose(1, 2)
+

 logger = logging.get_logger(__name__)


+def eager_attn(query, key, value, attn_mask=None, scale=None):
+    scale = 1 / query.shape[-1] ** 0.5 if scale is None else scale
+    query = query * scale
+    attn = torch.matmul(query, key.transpose(-2, -1))
+    if attn_mask is not None:
+        attn = attn + attn_mask
+    attn = attn.softmax(-1)
+    return attn @ value
+
+
+def attention(q, k, v, attn_mask=None, attn_impl: Optional[str] = None, scale: Optional[float] = None):
+    """
+    q: [B, Lq, Nq, C1]
+    k: [B, Lk, Nk, C1]
+    v: [B, Lk, Nk, C2]
+    """
+    assert attn_impl in [
+        None,
+        "auto",
+        "eager",
+        "flash_attn_2",
+        "flash_attn_3",
+        "xformers",
+        "sdpa",
+        "sage_attn",
+        "sparge_attn",
+    ]
+    if attn_impl is None or attn_impl == "auto":
+        if FLASH_ATTN_3_AVAILABLE:
+            return flash_attn3(q, k, v, softmax_scale=scale)
+        elif FLASH_ATTN_2_AVAILABLE:
+            return flash_attn2(q, k, v, softmax_scale=scale)
+        elif XFORMERS_AVAILABLE:
+            return xformers_attn(q, k, v, attn_bias=attn_mask, scale=scale)
+        elif SDPA_AVAILABLE:
+            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        else:
+            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+    else:
+        if attn_impl == "eager":
+            return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif attn_impl == "flash_attn_3":
+            return flash_attn3(q, k, v, softmax_scale=scale)
+        elif attn_impl == "flash_attn_2":
+            return flash_attn2(q, k, v, softmax_scale=scale)
+        elif attn_impl == "xformers":
+            return xformers_attn(q, k, v, attn_bias=attn_mask, scale=scale)
+        elif attn_impl == "sdpa":
+            return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif attn_impl == "sage_attn":
+            return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        elif attn_impl == "sparge_attn":
+            return sparge_attn(q, k, v, attn_mask=attn_mask, scale=scale)
+        else:
+            raise ValueError(f"Invalid attention implementation: {attn_impl}")
+
+
 class Attention(nn.Module):
     def __init__(
         self,
@@ -18,7 +121,7 @@ def __init__(
         bias_kv=False,
         bias_out=False,
         scale=None,
-        attn_implementation: str = "sdpa",
+        attn_impl: Optional[str] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float16,
     ):
@@ -32,106 +135,20 @@ def __init__(
         self.to_k = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
         self.to_v = nn.Linear(kv_dim, dim_inner, bias=bias_kv, device=device, dtype=dtype)
         self.to_out = nn.Linear(dim_inner, q_dim, bias=bias_out, device=device, dtype=dtype)
-
+        self.attn_impl = attn_impl
         self.scale = scale
-        self.attn_implementation = self._get_actual_attn_implementation(attn_implementation)
-
-    @staticmethod
-    def _get_actual_attn_implementation(attn_implementation):
-        supported_implementations = ("xformers", "sdpa", "eager")
-        if attn_implementation not in supported_implementations:
-            raise ValueError(
-                f"attn_implementation must be one of {supported_implementations}, but got '{attn_implementation}'"
-            )
-
-        actual_implementation = "eager" if attn_implementation == "eager" else ""
-        if attn_implementation == "xformers":
-            try:
-                from xformers.ops import memory_efficient_attention
-
-                actual_implementation = "xformers"
-            except ImportError:
-                pass
-        if not actual_implementation or attn_implementation == "sdpa":
-            use_mps = torch.backends.mps.is_available()
-            if hasattr(torch.nn.functional, "scaled_dot_product_attention") and not use_mps:
-                actual_implementation = "sdpa"
-
-        if actual_implementation != attn_implementation:
-            warning_msg = (
-                "xformers is not supported on this platform"
-                if attn_implementation == "xformers"
-                else "torch.nn.functional.scaled_dot_product_attention is not supported"
-            )
-            logger.warning(f"{warning_msg}, fallback to '{actual_implementation}' attention")
-        return actual_implementation
-
-    def sdpa_attn(self, hidden_states, encoder_hidden_states, attn_mask=None):
-        q = self.to_q(hidden_states)
-        k = self.to_k(encoder_hidden_states)
-        v = self.to_v(encoder_hidden_states)
-
-        q = rearrange(q, "b s (n d) -> b n s d", n=self.num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=self.num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=self.num_heads)
-
-        hidden_states = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)
-        hidden_states = rearrange(hidden_states, "b n s d -> b s (n d)", n=self.num_heads)
-        hidden_states = hidden_states.to(q.dtype)
-        hidden_states = self.to_out(hidden_states)
-        return hidden_states
-
-    def xformers_attn(self, hidden_states, encoder_hidden_states, attn_mask=None):
-        import xformers.ops as xops
-
-        q = self.to_q(hidden_states)
-        k = self.to_k(encoder_hidden_states)
-        v = self.to_v(encoder_hidden_states)
-        q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
-        k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
-        v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
-
-        hidden_states = xops.memory_efficient_attention(q, k, v, attn_bias=attn_mask, scale=self.scale)
-        hidden_states = rearrange(hidden_states, "b s n d -> b s (n d)")
-        hidden_states = hidden_states.to(q.dtype)
-        hidden_states = self.to_out(hidden_states)
-        return hidden_states
-
-    def eager_attn(self, hidden_states, encoder_hidden_states, attn_mask=None):
-        q = self.to_q(hidden_states)
-        k = self.to_k(encoder_hidden_states)
-        v = self.to_v(encoder_hidden_states)
-        q = rearrange(q, "b s (n d) -> b n s d", n=self.num_heads)
-        k = rearrange(k, "b s (n d) -> b n s d", n=self.num_heads)
-        v = rearrange(v, "b s (n d) -> b n s d", n=self.num_heads)
-
-        hidden_states = self._eager_attn(q, k, v, attn_bias=attn_mask, scale=self.scale)
-        hidden_states = rearrange(hidden_states, "b n s d -> b s (n d)", n=self.num_heads)
-        hidden_states = hidden_states.to(q.dtype)
-        hidden_states = self.to_out(hidden_states)
-        return hidden_states
-
-    @staticmethod
-    def _eager_attn(query, key, value, attn_bias=None, scale=None):
-        scale = 1 / query.shape[-1] ** 0.5 if scale is None else scale
-        query = query * scale
-        attn = torch.matmul(query, key.transpose(-2, -1))
-        if attn_bias is not None:
-            attn = attn + attn_bias
-        attn = attn.softmax(-1)
-        return attn @ value

     def forward(
         self,
-        hidden_states,
-        encoder_hidden_states=None,
-        attn_mask=None,
+        x: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
     ):
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-
-        if self.attn_implementation == "xformers":
-            return self.xformers_attn(hidden_states, encoder_hidden_states, attn_mask)
-        if self.attn_implementation == "sdpa":
-            return self.sdpa_attn(hidden_states, encoder_hidden_states, attn_mask)
-        return self.eager_attn(hidden_states, encoder_hidden_states, attn_mask)
+        if y is None:
+            y = x
+        q = rearrange(self.to_q(x), "b s (n d) -> b s n d", n=self.num_heads)
+        k = rearrange(self.to_k(y), "b s (n d) -> b s n d", n=self.num_heads)
+        v = rearrange(self.to_v(y), "b s (n d) -> b s n d", n=self.num_heads)
+        out = attention(q, k, v, attn_mask=attn_mask, attn_impl=self.attn_impl, scale=self.scale)
+        out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
+        return self.to_out(out)
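
The refactor above replaces the per-instance backend selection (_get_actual_attn_implementation plus one method per backend) with a single module-level attention() dispatcher keyed by attn_impl. A minimal sketch of calling it directly, assuming this module is importable (its path is not captured in this extraction) and that a CUDA device with float16 support is available; the tensor sizes are arbitrary illustrative values:

    import torch

    # q, k, v follow the documented layout [B, L, N, C]: batch, sequence, heads, head_dim
    q = torch.randn(1, 77, 8, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 77, 8, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 77, 8, 64, dtype=torch.float16, device="cuda")

    # attn_impl=None (or "auto") falls through the availability flags in order:
    # flash_attn_3 -> flash_attn_2 -> xformers -> sdpa -> eager
    out = attention(q, k, v, attn_impl=None)

    # Or pin a specific backend; unsupported names fail the assert or raise ValueError.
    out = attention(q, k, v, attn_impl="sdpa", scale=None)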

diffsynth_engine/models/basic/unet_helper.py

Lines changed: 2 additions & 2 deletions

@@ -51,12 +51,12 @@ def __init__(
     def forward(self, hidden_states, encoder_hidden_states):
         # 1. Self-Attention
         norm_hidden_states = self.norm1(hidden_states)
-        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+        attn_output = self.attn1(norm_hidden_states)
         hidden_states = attn_output + hidden_states

         # 2. Cross-Attention
         norm_hidden_states = self.norm2(hidden_states)
-        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, y=encoder_hidden_states)
         hidden_states = attn_output + hidden_states

         # 3. Feed-forward
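
These call sites track the new Attention.forward signature: the first argument x is the query sequence, and the optional y supplies keys and values, defaulting to x. A short sketch of the convention, assuming attn is an already constructed Attention layer and the tensors match its configured dimensions:

    out = attn(x)                           # self-attention: keys/values come from x itself
    out = attn(x, y=encoder_hidden_states)  # cross-attention against a conditioning sequence
    out = attn(x, y=encoder_hidden_states, attn_mask=mask)  # an attention mask can still be passed through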

diffsynth_engine/models/components/vae.py

Lines changed: 0 additions & 1 deletion

@@ -86,7 +86,6 @@ def __init__(
             bias_q=True,
             bias_kv=True,
             bias_out=True,
-            attn_implementation="xformers",
             device=device,
             dtype=dtype,
         )
