Changes from 22 commits
52 changes: 26 additions & 26 deletions diffsynth/distributed/xdit_context_parallel.py
@@ -40,9 +40,9 @@ def rope_apply(x, freqs, num_heads):
return x_out.to(x.dtype)

def usp_dit_forward(self,
x: torch.Tensor,
latents: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
encoder_hidden_states: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
@@ -52,80 +52,80 @@ def usp_dit_forward(self,
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
encoder_hidden_states = self.text_embedding(encoder_hidden_states)

if self.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
latents = torch.cat([latents, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
encoder_hidden_states = torch.cat([clip_embdding, encoder_hidden_states], dim=1)

x, (f, h, w) = self.patchify(x)
latents, (f, h, w) = self.patchify(latents)

freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
], dim=-1).reshape(f * h * w, 1, -1).to(latents.device)

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward

# Context Parallel
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
chunks = torch.chunk(latents, get_sequence_parallel_world_size(), dim=1)
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
x = chunks[get_sequence_parallel_rank()]
latents = chunks[get_sequence_parallel_rank()]

for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
latents = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
latents, encoder_hidden_states, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
latents = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
latents, encoder_hidden_states, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
latents = block(latents, encoder_hidden_states, t_mod, freqs)

x = self.head(x, t)
latents = self.head(latents, t)

# Context Parallel
x = get_sp_group().all_gather(x, dim=1)
x = x[:, :-pad_shape] if pad_shape > 0 else x
latents = get_sp_group().all_gather(latents, dim=1)
latents = latents[:, :-pad_shape] if pad_shape > 0 else latents

# unpatchify
x = self.unpatchify(x, (f, h, w))
return x
latents = self.unpatchify(latents, (f, h, w))
return latents


def usp_attn_forward(self, x, freqs):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
def usp_attn_forward(self, latents, freqs):
q = self.norm_q(self.q(latents))
k = self.norm_k(self.k(latents))
v = self.v(latents)

q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

x = xFuserLongContextAttention()(
latents = xFuserLongContextAttention()(
None,
query=q,
key=k,
value=v,
)
x = x.flatten(2)
latents = latents.flatten(2)

del q, k, v
torch.cuda.empty_cache()
return self.o(x)
return self.o(latents)
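As a side note on the context-parallel section above: usp_dit_forward pads every chunk up to the length of the first one before each rank takes its shard, then trims the padding again after the all-gather. A minimal single-process sketch of that round trip in plain torch, with a simple concatenation standing in for get_sp_group().all_gather(dim=1); the helper name and tensor sizes are illustrative only, not part of the repository:

import torch

def split_pad_gather(latents: torch.Tensor, world_size: int) -> torch.Tensor:
    # Split the sequence dimension as usp_dit_forward does.
    chunks = torch.chunk(latents, world_size, dim=1)
    # Only the last chunk can be shorter; remember how much padding it gets.
    pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
    chunks = [
        torch.nn.functional.pad(c, (0, 0, 0, chunks[0].shape[1] - c.shape[1]), value=0)
        for c in chunks
    ]
    # Each rank would now run the transformer blocks on chunks[rank]; here we
    # skip the blocks and emulate the all_gather with a concatenation.
    gathered = torch.cat(chunks, dim=1)
    # Trim the zeros that were appended to the last chunk.
    return gathered[:, :-pad_shape] if pad_shape > 0 else gathered

latents = torch.randn(1, 10, 64)  # (batch, sequence, dim), length not divisible by 4
assert torch.equal(split_pad_gather(latents, world_size=4), latents)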
52 changes: 26 additions & 26 deletions diffsynth/models/wan_video_dit.py
@@ -211,7 +211,7 @@ def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.gate = GateModule()

def forward(self, x, context, t_mod, freqs):
def forward(self, hidden_states, encoder_hidden_states, t_mod, freqs):
has_seq = len(t_mod.shape) == 4
chunk_dim = 2 if has_seq else 1
# msa: multi-head self-attention mlp: multi-layer perceptron
@@ -222,12 +222,12 @@ def forward(self, x, context, t_mod, freqs):
shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
)
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
x = x + self.cross_attn(self.norm3(x), context)
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = self.gate(x, gate_mlp, self.ffn(input_x))
return x
input_x = modulate(self.norm1(hidden_states), shift_msa, scale_msa)
hidden_states = self.gate(hidden_states, gate_msa, self.self_attn(input_x, freqs))
hidden_states = hidden_states + self.cross_attn(self.norm3(hidden_states), encoder_hidden_states)
input_x = modulate(self.norm2(hidden_states), shift_mlp, scale_mlp)
hidden_states = self.gate(hidden_states, gate_mlp, self.ffn(input_x))
return hidden_states


class MLP(torch.nn.Module):
@@ -244,10 +244,10 @@ def __init__(self, in_dim, out_dim, has_pos_emb=False):
if has_pos_emb:
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))

def forward(self, x):
def forward(self, hidden_states):
if self.has_pos_emb:
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
return self.proj(x)
hidden_states = hidden_states + self.emb_pos.to(dtype=hidden_states.dtype, device=hidden_states.device)
return self.proj(hidden_states)


class Head(nn.Module):
@@ -259,14 +259,14 @@ def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

def forward(self, x, t_mod):
def forward(self, hidden_states, t_mod):
if len(t_mod.shape) == 3:
shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2)
x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)))
hidden_states = (self.head(self.norm(hidden_states) * (1 + scale.squeeze(2)) + shift.squeeze(2)))
else:
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + scale) + shift))
return x
hidden_states = (self.head(self.norm(hidden_states) * (1 + scale) + shift))
return hidden_states


class WanModel(torch.nn.Module):
@@ -354,9 +354,9 @@ def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
)

def forward(self,
x: torch.Tensor,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
encoder_hidden_states: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
@@ -366,20 +366,20 @@ def forward(self,
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
context = self.text_embedding(encoder_hidden_states)

if self.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) # (b, c_x + c_y, f, h, w)
Contributor (critical):

There seems to be a copy-paste error here. encoder_hidden_states is being concatenated with hidden_states, but based on the original code and the comment, it should be y (the reference image latents). Concatenating text embeddings with image latents along the channel dimension is likely incorrect.

Suggested change
-            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) # (b, c_x + c_y, f, h, w)
+            hidden_states = torch.cat([hidden_states, y], dim=1) # (b, c_x + c_y, f, h, w)
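The shape argument behind this suggestion is easy to check: y shares the (f, h, w) latent layout with hidden_states, so a channel-wise concat is well defined, while the text embeddings are a flat token sequence. A small sanity check with made-up sizes (the channel, sequence, and embedding dimensions below are placeholders, not the model's real ones):

import torch

b, f, h, w = 1, 4, 8, 8
hidden_states = torch.randn(b, 16, f, h, w)        # video latents: (b, c_x, f, h, w)
y = torch.randn(b, 20, f, h, w)                    # reference-image latents: (b, c_y, f, h, w)
encoder_hidden_states = torch.randn(b, 512, 4096)  # text embeddings: (b, seq, dim)

# Concatenating the image latents along the channel axis matches the comment.
print(torch.cat([hidden_states, y], dim=1).shape)  # (b, c_x + c_y, f, h, w)

# Concatenating the text embeddings instead fails outright: the tensors do not
# even have the same number of dimensions.
try:
    torch.cat([hidden_states, encoder_hidden_states], dim=1)
except RuntimeError as err:
    print("shape mismatch:", err)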

clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)

x, (f, h, w) = self.patchify(x)
hidden_states, (f, h, w) = self.patchify(hidden_states)

freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
], dim=-1).reshape(f * h * w, 1, -1).to(hidden_states.device)

def create_custom_forward(module):
def custom_forward(*inputs):
Expand All @@ -390,23 +390,23 @@ def custom_forward(*inputs):
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
Comment on lines +393 to 403

Contributor (critical):

The variable x is used here in the arguments to torch.utils.checkpoint.checkpoint, but it is not defined in the scope of this function after the refactoring. The variable was renamed to hidden_states. You should use hidden_states instead of x.

Suggested change
-                        hidden_states = torch.utils.checkpoint.checkpoint(
-                            create_custom_forward(block),
-                            x, context, t_mod, freqs,
-                            use_reentrant=False,
-                        )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        x, context, t_mod, freqs,
-                        use_reentrant=False,
-                    )
+                        hidden_states = torch.utils.checkpoint.checkpoint(
+                            create_custom_forward(block),
+                            hidden_states, context, t_mod, freqs,
+                            use_reentrant=False,
+                        )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states, context, t_mod, freqs,
+                        use_reentrant=False,
+                    )
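Because x no longer exists in the refactored function body, the call as currently written would raise a NameError the first time a block runs with gradient checkpointing enabled. A toy sketch of the intended call pattern with the renamed tensor; the ToyBlock module and the tensor sizes are invented for illustration, and only the checkpoint call itself mirrors the code above:

import torch
import torch.nn as nn
import torch.utils.checkpoint

class ToyBlock(nn.Module):
    # Stand-in for a DiT block: takes the running hidden states plus the
    # conditioning tensors and returns updated hidden states.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, context, t_mod, freqs):
        return self.proj(hidden_states) + context + t_mod

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

dim = 32
block = ToyBlock(dim)
hidden_states = torch.randn(2, 8, dim, requires_grad=True)
context = torch.randn(2, 8, dim)
t_mod = torch.randn(2, 1, dim)
freqs = torch.randn(8, 1, dim)  # unused by the toy block, kept for the call shape

# The renamed tensor is the one that must be passed through checkpoint();
# referencing x here would reproduce exactly the NameError the comment points out.
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    hidden_states, context, t_mod, freqs,
    use_reentrant=False,
)
hidden_states.sum().backward()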

else:
x = block(x, context, t_mod, freqs)
hidden_states = block(hidden_states, context, t_mod, freqs)

x = self.head(x, t)
x = self.unpatchify(x, (f, h, w))
return x
hidden_states = self.head(hidden_states, t)
hidden_states = self.unpatchify(hidden_states, (f, h, w))
return hidden_states

@staticmethod
def state_dict_converter():
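For readers following the renames in DiTBlock.forward, the per-block flow is: adaptive modulation of the normalized hidden states, gated self-attention, cross-attention against encoder_hidden_states, then a gated feed-forward branch. A standalone toy version of that flow, with stock torch.nn layers standing in for the repository's own attention modules and a modulate helper assumed to follow the usual DiT definition (all sizes are illustrative):

import torch
import torch.nn as nn

def modulate(x, shift, scale):
    # Assumed DiT-style modulation, consistent with the Head forward above.
    return x * (1 + scale) + shift

b, s, dim = 2, 16, 64
hidden_states = torch.randn(b, s, dim)
encoder_hidden_states = torch.randn(b, 10, dim)
t_mod = torch.randn(b, 6, dim)  # time embedding projected to six modulation tensors

# Six (b, 1, dim) tensors, broadcast over the sequence dimension.
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod.chunk(6, dim=1)

norm1 = nn.LayerNorm(dim, elementwise_affine=False)
norm2 = nn.LayerNorm(dim, elementwise_affine=False)
norm3 = nn.LayerNorm(dim, elementwise_affine=False)
self_attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
cross_attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

# Gated self-attention on the modulated hidden states.
input_x = modulate(norm1(hidden_states), shift_msa, scale_msa)
attn_out, _ = self_attn(input_x, input_x, input_x)
hidden_states = hidden_states + gate_msa * attn_out

# Cross-attention against the text/image conditioning sequence.
cross_out, _ = cross_attn(norm3(hidden_states), encoder_hidden_states, encoder_hidden_states)
hidden_states = hidden_states + cross_out

# Gated feed-forward on the second modulated branch.
input_x = modulate(norm2(hidden_states), shift_mlp, scale_mlp)
hidden_states = hidden_states + gate_mlp * ffn(input_x)
print(hidden_states.shape)  # torch.Size([2, 16, 64])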