
Commit d8bd10e

Commit message: up
1 parent: 16c1467

File tree

4 files changed: +280 −246 lines


src/diffusers/models/normalization.py

Lines changed: 2 additions & 1 deletion
@@ -550,7 +550,8 @@ def forward(self, hidden_states):
             hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
-        elif is_torch_version(">=", "2.4"):
+        # YiYi TODO: testing only, remove this change before merging
+        elif is_torch_version(">=", "3.3"):
             if self.weight is not None:
                 # convert into half-precision if necessary
                 if self.weight.dtype in [torch.float16, torch.bfloat16]:
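The only functional change here is the version gate: `is_torch_version(">=", ...)` compares the installed torch release against the given bound, so raising it from "2.4" to "3.3" means this branch no longer runs on any torch release available today and the module presumably falls back to the non-fused implementation that follows; the inline TODO flags it as a temporary, testing-only tweak. A rough sketch of the comparison being gated on (the helper below is a stand-in for illustration, not diffusers' actual `is_torch_version`):

from packaging import version
import torch

def torch_at_least(minimum: str) -> bool:
    # Stand-in for diffusers' is_torch_version(">=", minimum): compare the
    # installed torch release (local build suffix stripped) against a bound.
    return version.parse(torch.__version__.split("+")[0]) >= version.parse(minimum)

# With the gate raised to "3.3" the second check prints False on current torch
# releases, so the RMSNorm forward skips this elif branch entirely.
print(torch_at_least("2.4"), torch_at_least("3.3"))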

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 58 additions & 38 deletions
@@ -115,46 +115,63 @@ def __init__(
         self.theta = theta
         self._causal_rope_fix = _causal_rope_fix

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        num_frames: int,
-        height: int,
-        width: int,
-        frame_rate: Optional[int] = None,
-        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        batch_size = hidden_states.size(0)
-
+
+    def _prepare_video_coords(self, batch_size: int, num_frames: int, height: int, width: int, rope_interpolation_scale: Tuple[torch.Tensor, float, float], device: torch.device) -> torch.Tensor:
         # Always compute rope in fp32
-        grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
-        grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
-        grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
+        grid_h = torch.arange(height, dtype=torch.float32, device=device)
+        grid_w = torch.arange(width, dtype=torch.float32, device=device)
+        grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
         grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
         grid = torch.stack(grid, dim=0)
         grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

-        if rope_interpolation_scale is not None:
-            if isinstance(rope_interpolation_scale, tuple):
-                # This will be deprecated in v0.34.0
-                grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
-                grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
-                grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+        if isinstance(rope_interpolation_scale, tuple):
+            # This will be deprecated in v0.34.0
+            grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
+            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
+            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+        else:
+            if not self._causal_rope_fix:
+                grid[:, 0:1] = (
+                    grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
+                )
             else:
-                if not self._causal_rope_fix:
-                    grid[:, 0:1] = (
-                        grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
-                    )
-                else:
-                    grid[:, 0:1] = (
-                        ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
-                        * self.patch_size_t
-                        / self.base_num_frames
-                    )
-                grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
-                grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width
+                grid[:, 0:1] = (
+                    ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
+                    * self.patch_size_t
+                    / self.base_num_frames
+                )
+            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
+            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width

         grid = grid.flatten(2, 4).transpose(1, 2)
+
+        return grid
+
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        num_frames: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        frame_rate: Optional[int] = None,
+        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
+        video_coords: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = hidden_states.size(0)
+
+        if video_coords is None:
+            grid = self._prepare_video_coords(batch_size, num_frames, height, width, rope_interpolation_scale=rope_interpolation_scale, device=hidden_states.device)
+        else:
+            grid = torch.stack(
+                [
+                    video_coords[:, 0] / self.base_num_frames,
+                    video_coords[:, 1] / self.base_height,
+                    video_coords[:, 2] / self.base_width
+                ],
+                dim=-1,
+            )

         start = 1.0
         end = self.theta
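The refactor above moves the grid construction into a `_prepare_video_coords` helper and adds a `video_coords` escape hatch to `forward`: when it is supplied, the latent dimensions become optional and the per-token (frame, height, width) coordinates are simply divided by `base_num_frames`, `base_height`, and `base_width`. A minimal usage sketch of both call paths, with made-up shapes (`rope` stands for an already constructed LTXVideoRotaryPosEmbed; the coordinate values are placeholders):

import torch

batch_size, num_frames, height, width, dim = 2, 8, 4, 4, 64
seq_len = num_frames * height * width
hidden_states = torch.randn(batch_size, seq_len, dim)

# Path 1: let the module derive coordinates from the latent grid
# (tuple interpolation scale shown here; that form is deprecated for v0.34.0).
cos, sin = rope(hidden_states, num_frames, height, width,
                rope_interpolation_scale=(1.0, 32.0, 32.0))

# Path 2: pass explicit per-token coordinates, shaped [batch, 3, seq_len] to
# match the indexing in the new forward, ordered (frame, height, width);
# they are normalized by the base sizes internally.
video_coords = torch.zeros(batch_size, 3, seq_len)
cos, sin = rope(hidden_states, video_coords=video_coords)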
@@ -387,11 +404,12 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         encoder_attention_mask: torch.Tensor,
-        num_frames: int,
-        height: int,
-        width: int,
-        frame_rate: int,
+        num_frames: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        frame_rate: Optional[int] = None,
         rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
+        video_coords: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> torch.Tensor:
@@ -414,7 +432,8 @@ def forward(
             msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
             deprecate("rope_interpolation_scale", "0.34.0", msg)

-        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale)
+
+        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords)

         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
@@ -475,5 +494,6 @@ def apply_rotary_emb(x, freqs):
    cos, sin = freqs
    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, H, D // 2]
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
-    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+    # YiYi TODO: testing only, remove this change before merging
+    out = (x * cos.to(x.dtype) + x_rotated * sin.to(x.dtype)).to(x.dtype)
    return out
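The last hunk changes `apply_rotary_emb` from performing the rotation in fp32 to performing it in the activation dtype (again flagged as testing-only). A self-contained sketch of the two variants, using the same real/imaginary interleaving as the diff, to show they differ only by low-precision rounding (the tensors below are placeholders, purely for comparing numerics):

import torch

def rotate_fp32(x, cos, sin):
    # Original behaviour: upcast the activations to fp32 for the rotation.
    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

def rotate_in_dtype(x, cos, sin):
    # Patched behaviour: cast cos/sin down and rotate in the activation dtype.
    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
    return (x * cos.to(x.dtype) + x_rotated * sin.to(x.dtype)).to(x.dtype)

x = torch.randn(1, 16, 8, dtype=torch.bfloat16)
cos, sin = torch.randn(16, 8), torch.randn(16, 8)
print((rotate_fp32(x, cos, sin) - rotate_in_dtype(x, cos, sin)).abs().max())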
