Commit 93fedd9

Support LTXV 0.9.5.
Credits: Lightricks team.
1 parent: 745b136

File tree: 11 files changed (+659, -139 lines)

comfy/ldm/lightricks/model.py

Lines changed: 18 additions & 37 deletions
```diff
@@ -7,7 +7,7 @@
 import math
 from typing import Dict, Optional, Tuple
 
-from .symmetric_patchifier import SymmetricPatchifier
+from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
 
 
 def get_timestep_embedding(
@@ -377,12 +377,16 @@ def __init__(self,
 
                  positional_embedding_theta=10000.0,
                  positional_embedding_max_pos=[20, 2048, 2048],
+                 causal_temporal_positioning=False,
+                 vae_scale_factors=(8, 32, 32),
                  dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
         self.generator = None
+        self.vae_scale_factors = vae_scale_factors
         self.dtype = dtype
         self.out_channels = in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
+        self.causal_temporal_positioning = causal_temporal_positioning
 
         self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
 
@@ -416,50 +420,31 @@ def __init__(self,
 
         self.patchifier = SymmetricPatchifier(1)
 
-    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
         patches_replace = transformer_options.get("patches_replace", {})
 
-        indices_grid = self.patchifier.get_grid(
-            orig_num_frames=x.shape[2],
-            orig_height=x.shape[3],
-            orig_width=x.shape[4],
-            batch_size=x.shape[0],
-            scale_grid=((1 / frame_rate) * 8, 32, 32),
-            device=x.device,
-        )
-
-        if guiding_latent is not None:
-            ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
-            input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
-            ts *= input_ts
-            ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
-            timestep = self.patchifier.patchify(ts)
-            input_x = x.clone()
-            x[:, :, 0] = guiding_latent[:, :, 0]
-            if guiding_latent_noise_scale > 0:
-                if self.generator is None:
-                    self.generator = torch.Generator(device=x.device).manual_seed(42)
-                elif self.generator.device != x.device:
-                    self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
-
-                noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
-                scale = guiding_latent_noise_scale * (input_ts ** 2)
-                guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
-
-                x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
+        orig_shape = list(x.shape)
 
+        x, latent_coords = self.patchifier.patchify(x)
+        pixel_coords = latent_to_pixel_coords(
+            latent_coords=latent_coords,
+            scale_factors=self.vae_scale_factors,
+            causal_fix=self.causal_temporal_positioning,
+        )
 
-        orig_shape = list(x.shape)
+        if keyframe_idxs is not None:
+            pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
 
-        x = self.patchifier.patchify(x)
+        fractional_coords = pixel_coords.to(torch.float32)
+        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
 
         x = self.patchify_proj(x)
         timestep = timestep * 1000.0
 
         if attention_mask is not None and not torch.is_floating_point(attention_mask):
             attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
 
-        pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
+        pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
 
         batch_size = x.shape[0]
         timestep, embedded_timestep = self.adaln_single(
@@ -519,8 +504,4 @@ def block_wrap(args):
             out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
         )
 
-        if guiding_latent is not None:
-            x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
-
-        # print("res", x)
         return x
```
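
Taken together, this diff replaces two mechanisms at once: RoPE positions are now derived from per-token pixel/time coordinates rather than the old pre-scaled `get_grid`, and the first-frame guiding-latent path is gone (conditioning tokens instead declare their position via `keyframe_idxs`). A minimal sketch of the new coordinate math; the `[1, 3, 2]` coordinate tensor below is made up for illustration, not what the patchifier emits for a real latent:

```python
import torch

# Pretend the patchifier emitted two tokens with (frame, row, col) latent coords:
latent_coords = torch.tensor([[[0, 1],    # frame index per token
                               [0, 0],    # row index per token
                               [0, 1]]])  # col index per token; shape [1, 3, 2]

# latent_to_pixel_coords: scale each axis by the VAE downscale factors.
vae_scale_factors = (8, 32, 32)  # commit default: 8 frames, 32x32 pixels per latent
pixel_coords = latent_coords * torch.tensor(vae_scale_factors)[None, :, None]
# pixel_coords -> [[[0, 8], [0, 0], [0, 32]]]

# forward() then divides the temporal axis by the frame rate, so the RoPE
# input measures time in seconds instead of frame indices.
frame_rate = 25
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] *= 1.0 / frame_rate
# fractional_coords[:, 0] -> [[0.0, 0.32]]; this is what precompute_freqs_cis sees.
```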

comfy/ldm/lightricks/symmetric_patchifier.py

Lines changed: 45 additions & 33 deletions
```diff
@@ -6,16 +6,29 @@
 from torch import Tensor
 
 
-def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
-    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
-    dims_to_append = target_dims - x.ndim
-    if dims_to_append < 0:
-        raise ValueError(
-            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
-        )
-    elif dims_to_append == 0:
-        return x
-    return x[(...,) + (None,) * dims_to_append]
+def latent_to_pixel_coords(
+    latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
+) -> Tensor:
+    """
+    Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
+    configuration.
+    Args:
+        latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
+            containing the latent corner coordinates of each token.
+        scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
+        causal_fix (bool): Whether to take into account the different temporal scale
+            of the first frame. Default = False for backwards compatibility.
+    Returns:
+        Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
+    """
+    pixel_coords = (
+        latent_coords
+        * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
+    )
+    if causal_fix:
+        # Fix temporal scale for first frame to 1 due to causality
+        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
+    return pixel_coords
 
 
 class Patchifier(ABC):
```
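
The `causal_fix` branch rewards a worked example: the causal VAE encodes the first pixel frame on its own, while each later latent frame covers `scale_factors[0]` pixel frames, so naive scaling would place every token after frame 0 one causal frame too late. Checking the arithmetic by hand with the default temporal factor of 8:

```python
import torch

t = torch.tensor([0, 1, 2, 3])         # temporal latent coordinates
scaled = t * 8                          # naive scaling -> [0, 8, 16, 24]
fixed = (scaled + 1 - 8).clamp(min=0)   # causal fix    -> [0, 1, 9, 17]

# Latent frame 0 still maps to pixel frame 0; frame 1 maps to pixel frame 1
# (the first latent frame only "owns" a single pixel frame), and every frame
# after that starts 8 pixel frames later.
print(scaled.tolist(), fixed.tolist())
```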
```diff
@@ -44,44 +57,43 @@ def unpatchify(
     def patch_size(self):
         return self._patch_size
 
-    def get_grid(
-        self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
+    def get_latent_coords(
+        self, latent_num_frames, latent_height, latent_width, batch_size, device
     ):
-        f = orig_num_frames // self._patch_size[0]
-        h = orig_height // self._patch_size[1]
-        w = orig_width // self._patch_size[2]
-        grid_h = torch.arange(h, dtype=torch.float32, device=device)
-        grid_w = torch.arange(w, dtype=torch.float32, device=device)
-        grid_f = torch.arange(f, dtype=torch.float32, device=device)
-        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
-        grid = torch.stack(grid, dim=0)
-        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
-
-        if scale_grid is not None:
-            for i in range(3):
-                if isinstance(scale_grid[i], Tensor):
-                    scale = append_dims(scale_grid[i], grid.ndim - 1)
-                else:
-                    scale = scale_grid[i]
-                grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
-
-        grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
-        return grid
+        """
+        Return a tensor of shape [batch_size, 3, num_patches] containing the
+        top-left corner latent coordinates of each latent patch.
+        The tensor is repeated for each batch element.
+        """
+        latent_sample_coords = torch.meshgrid(
+            torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
+            torch.arange(0, latent_height, self._patch_size[1], device=device),
+            torch.arange(0, latent_width, self._patch_size[2], device=device),
+            indexing="ij",
+        )
+        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+        latent_coords = rearrange(
+            latent_coords, "b c f h w -> b c (f h w)", b=batch_size
+        )
+        return latent_coords
 
 
 class SymmetricPatchifier(Patchifier):
     def patchify(
         self,
         latents: Tensor,
     ) -> Tuple[Tensor, Tensor]:
+        b, _, f, h, w = latents.shape
+        latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
         latents = rearrange(
             latents,
             "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
             p1=self._patch_size[0],
             p2=self._patch_size[1],
             p3=self._patch_size[2],
         )
-        return latents
+        return latents, latent_coords
 
     def unpatchify(
         self,
```
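
Because `patchify` now returns a `(tokens, coords)` tuple, every caller has to unpack both values. A shape-level sketch of the new contract; the `(1, 2, 2)` patch size is a hypothetical choice to make the rearrange visible (the LTXV model itself constructs `SymmetricPatchifier(1)`):

```python
import torch
from einops import rearrange

latents = torch.randn(1, 128, 2, 4, 4)  # (b, c, f, h, w)
p1, p2, p3 = 1, 2, 2                    # hypothetical patch size, for illustration
b, _, f, h, w = latents.shape

# get_latent_coords: top-left corner of every patch, stacked as (frame, row, col).
grid = torch.meshgrid(
    torch.arange(0, f, p1),
    torch.arange(0, h, p2),
    torch.arange(0, w, p3),
    indexing="ij",
)
latent_coords = rearrange(
    torch.stack(grid, dim=0).unsqueeze(0).repeat(b, 1, 1, 1, 1),
    "b c f h w -> b c (f h w)",
)

# patchify: flatten each (p1, p2, p3) block into one token.
tokens = rearrange(
    latents,
    "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
    p1=p1, p2=p2, p3=p3,
)
print(tokens.shape, latent_coords.shape)
# torch.Size([1, 8, 512]) torch.Size([1, 3, 8]) -- one coord triple per token
```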

comfy/ldm/lightricks/vae/causal_conv3d.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -15,6 +15,7 @@ def __init__(
         stride: Union[int, Tuple[int]] = 1,
         dilation: int = 1,
         groups: int = 1,
+        spatial_padding_mode: str = "zeros",
         **kwargs,
     ):
         super().__init__()
@@ -38,7 +39,7 @@ def __init__(
             stride=stride,
             dilation=dilation,
             padding=padding,
-            padding_mode="zeros",
+            padding_mode=spatial_padding_mode,
             groups=groups,
         )
 
```
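
This change just threads a configurable border fill through to the spatial padding of the inner `nn.Conv3d`; the temporal axis is presumably still padded by the causal logic elsewhere in this class, hence the `spatial_` prefix. A standalone illustration of what `padding_mode` affects, using plain PyTorch rather than this class:

```python
import torch
import torch.nn as nn

x = torch.ones(1, 1, 3, 4, 4)  # (b, c, f, h, w), all ones

for mode in ("zeros", "replicate"):
    # Spatial-only padding (0, 1, 1), as a causal conv would use.
    conv = nn.Conv3d(1, 1, kernel_size=3, padding=(0, 1, 1),
                     padding_mode=mode, bias=False)
    nn.init.constant_(conv.weight, 1.0)
    # The corner output sums a 3x3x3 window: 12 real ones with zero-fill,
    # 27 with replicated edges.
    print(mode, conv(x)[0, 0, 0, 0, 0].item())  # zeros -> 12.0, replicate -> 27.0
```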
