@@ -7,7 +7,7 @@
 import math
 from typing import Dict, Optional, Tuple
 
-from .symmetric_patchifier import SymmetricPatchifier
+from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
 
 
 def get_timestep_embedding(
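Note for readers: `latent_to_pixel_coords` converts the patchifier's per-token latent `(frame, row, col)` indices into pixel-space coordinates. A minimal sketch of the helper, assuming the `[batch, 3, num_tokens]` coordinate layout used later in this diff:

```python
import torch

def latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=False):
    # Scale each axis (frame, height, width) by its VAE compression factor.
    pixel_coords = latent_coords * torch.tensor(
        scale_factors, device=latent_coords.device
    )[None, :, None]
    if causal_fix:
        # A causal VAE decodes the first latent frame to a single pixel frame,
        # so later frames shift back by (scale - 1); clamp keeps t >= 0.
        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
    return pixel_coords
```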
@@ -377,12 +377,16 @@ def __init__(self,
 
                  positional_embedding_theta=10000.0,
                  positional_embedding_max_pos=[20, 2048, 2048],
+                 causal_temporal_positioning=False,
+                 vae_scale_factors=(8, 32, 32),
                  dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
         self.generator = None
+        self.vae_scale_factors = vae_scale_factors
         self.dtype = dtype
         self.out_channels = in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
+        self.causal_temporal_positioning = causal_temporal_positioning
 
         self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
 
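The two new constructor arguments feed that coordinate mapping: `vae_scale_factors=(8, 32, 32)` is the VAE's (temporal, height, width) compression, and `causal_temporal_positioning` toggles the causal offset. A rough numeric illustration, under the sketch above:

```python
import torch

latent_frames = torch.tensor([0, 1, 2])
plain = latent_frames * 8                           # tensor([ 0,  8, 16])
causal = (latent_frames * 8 + 1 - 8).clamp(min=0)   # tensor([0, 1, 9])
# With the causal fix, latent frame k > 0 starts at pixel frame 8k - 7,
# matching a first latent frame that covers only one pixel frame.
```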
@@ -416,50 +420,31 @@ def __init__(self,
 
         self.patchifier = SymmetricPatchifier(1)
 
-    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
         patches_replace = transformer_options.get("patches_replace", {})
 
-        indices_grid = self.patchifier.get_grid(
-            orig_num_frames=x.shape[2],
-            orig_height=x.shape[3],
-            orig_width=x.shape[4],
-            batch_size=x.shape[0],
-            scale_grid=((1 / frame_rate) * 8, 32, 32),
-            device=x.device,
-        )
-
-        if guiding_latent is not None:
-            ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
-            input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
-            ts *= input_ts
-            ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
-            timestep = self.patchifier.patchify(ts)
-            input_x = x.clone()
-            x[:, :, 0] = guiding_latent[:, :, 0]
-            if guiding_latent_noise_scale > 0:
-                if self.generator is None:
-                    self.generator = torch.Generator(device=x.device).manual_seed(42)
-                elif self.generator.device != x.device:
-                    self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
-
-                noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
-                scale = guiding_latent_noise_scale * (input_ts ** 2)
-                guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
-
-                x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
+        orig_shape = list(x.shape)
 
+        x, latent_coords = self.patchifier.patchify(x)
+        pixel_coords = latent_to_pixel_coords(
+            latent_coords=latent_coords,
+            scale_factors=self.vae_scale_factors,
+            causal_fix=self.causal_temporal_positioning,
+        )
 
-        orig_shape = list(x.shape)
+        if keyframe_idxs is not None:
+            pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
 
-        x = self.patchifier.patchify(x)
+        fractional_coords = pixel_coords.to(torch.float32)
+        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
 
         x = self.patchify_proj(x)
         timestep = timestep * 1000.0
 
         if attention_mask is not None and not torch.is_floating_point(attention_mask):
             attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
 
-        pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
+        pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
 
         batch_size = x.shape[0]
         timestep, embedded_timestep = self.adaln_single(
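In the rewritten `forward`, `patchify` now also returns per-token latent coordinates; these become pixel coordinates (optionally overridden for keyframe tokens via `keyframe_idxs`), and the temporal axis is divided by `frame_rate` before going to `precompute_freqs_cis`, so the rotary embedding sees time rather than raw frame indices. A small sketch of that coordinate path with made-up values:

```python
import torch

frame_rate = 25
# Hypothetical coords for three tokens, laid out [batch, (t, y, x), num_tokens].
pixel_coords = torch.tensor([[[0, 1, 9],     # temporal pixel-frame index
                              [0, 0, 32],    # vertical pixel position
                              [0, 32, 0]]])  # horizontal pixel position
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
# temporal axis is now [0.0, 0.04, 0.36] (seconds); spatial axes stay in pixels
```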
@@ -519,8 +504,4 @@ def block_wrap(args):
             out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
         )
 
-        if guiding_latent is not None:
-            x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
-
-        # print("res", x)
         return x
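Unchanged here but load-bearing: a boolean `attention_mask` is converted into an additive bias just before `precompute_freqs_cis`. A self-contained illustration of that conversion with toy values:

```python
import torch

x_dtype = torch.float16
attention_mask = torch.tensor([[1, 1, 0, 0]])  # 1 = attend, 0 = masked
bias = (attention_mask - 1).to(x_dtype).reshape(
    (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
) * torch.finfo(x_dtype).max
# bias == [[[[0., 0., -65504., -65504.]]]]; kept keys add 0, masked keys add
# the lowest representable value, driving their softmax weight to zero.
```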