
Commit 8da9638

more dataset fixes from stashed changes (#49)

Parent: f2a1626

File tree

  training/dataset.py
  training/prepare_dataset.py

2 files changed: 37 additions, 40 deletions

training/dataset.py

21 additions, 15 deletions

@@ -229,32 +229,38 @@ def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tenso
         pt_filename = f"{filename_without_ext}.pt"
 
         # The current path is something like: /a/b/c/d/videos/00001.mp4
-        # We need to reach: /a/b/c/d/latents/00001.pt
-        images_path = path.parent.parent.joinpath("image_latents")
-        latents_path = path.parent.parent.joinpath("latents")
-        embeds_path = path.parent.parent.joinpath("embeddings")
-
-        if not latents_path.exists() or not embeds_path.exists() or (self.image_to_video and not images_path.exists()):
+        # We need to reach: /a/b/c/d/video_latents/00001.pt
+        image_latents_path = path.parent.parent.joinpath("image_latents")
+        video_latents_path = path.parent.parent.joinpath("video_latents")
+        embeds_path = path.parent.parent.joinpath("prompt_embeds")
+
+        if (
+            not video_latents_path.exists()
+            or not embeds_path.exists()
+            or (self.image_to_video and not image_latents_path.exists())
+        ):
             raise ValueError(
-                f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `latents` and `embeddings`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
+                f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `video_latents` and `prompt_embeds`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
             )
 
         if self.image_to_video:
-            image_filepath = images_path.joinpath(pt_filename)
-        latent_filepath = latents_path.joinpath(pt_filename)
+            image_latent_filepath = image_latents_path.joinpath(pt_filename)
+        video_latent_filepath = video_latents_path.joinpath(pt_filename)
         embeds_filepath = embeds_path.joinpath(pt_filename)
 
-        if not latent_filepath.is_file() or not embeds_filepath.is_file():
+        if not video_latent_filepath.is_file() or not embeds_filepath.is_file():
             if self.image_to_video:
-                image_filepath = image_filepath.as_posix()
-            latent_filepath = latent_filepath.as_posix()
+                image_latent_filepath = image_latent_filepath.as_posix()
+            video_latent_filepath = video_latent_filepath.as_posix()
             embeds_filepath = embeds_filepath.as_posix()
             raise ValueError(
-                f"The file {latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`."
+                f"The file {video_latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`."
             )
 
-        images = torch.load(image_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None
-        latents = torch.load(latent_filepath, map_location="cpu", weights_only=True)
+        images = (
+            torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None
+        )
+        latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True)
         embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True)
 
         return images, latents, embeds
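
For orientation, the loader above expects the precomputed tensors to live in folders that are siblings of the videos folder (it walks up via path.parent.parent). A sketch of the layout after this commit, using the folder names from the diff; the sample name 00001 is illustrative:

data_root/
├── videos/         00001.mp4, ...  (raw clips)
├── video_latents/  00001.pt, ...   (VAE-encoded videos)
├── prompt_embeds/  00001.pt, ...   (text-encoder outputs)
└── image_latents/  00001.pt, ...   (only required when image_to_video is set)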

training/prepare_dataset.py

16 additions, 25 deletions

@@ -477,8 +477,6 @@ def collate_fn(data):
     # 3. Prepare models
     device = f"cuda:{rank}"
 
-    generator = torch.Generator(device).manual_seed(args.seed)
-
     if args.save_latents_and_embeddings:
         tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer")
         text_encoder = T5EncoderModel.from_pretrained(
@@ -520,29 +518,22 @@ def collate_fn(data):
 
             # Encode videos & images
             if args.save_latents_and_embeddings:
-                if args.save_image_latents:
-                    image_noise_sigma = torch.normal(
-                        mean=-3.0,
-                        std=0.5,
-                        size=(images.size(0),),
-                        generator=generator,
-                        device=device,
-                        dtype=weight_dtype,
-                    )
-                    image_noise_sigma = torch.exp(image_noise_sigma)
-                    noisy_images = (
-                        images
-                        + torch.empty_like(images).normal_(generator=generator)
-                        * image_noise_sigma[:, None, None, None, None]
-                    )
-                    image_latent_dist = vae.encode(noisy_images).latent_dist
-                    image_latents = image_latent_dist.sample() * vae.config.scaling_factor
-                    image_latents = image_latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-                    image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
-
-                latent_dist = vae.encode(videos).latent_dist
-                video_latents = latent_dist.sample(generator=generator) * vae.config.scaling_factor
-                video_latents = video_latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
+                if args.use_slicing:
+                    if args.save_image_latents:
+                        encoded_slices = [vae._encode(image_slice) for image_slice in images.split(1)]
+                        image_latents = torch.cat(encoded_slices)
+                        image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
+
+                    encoded_slices = [vae._encode(video_slice) for video_slice in videos.split(1)]
+                    video_latents = torch.cat(encoded_slices)
+
+                else:
+                    if args.save_image_latents:
+                        image_latents = vae._encode(images)
+                        image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
+
+                    video_latents = vae._encode(videos)
+
                 video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
 
             # Encode prompts
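
The new use_slicing branch trades throughput for peak memory: instead of pushing the whole batch through the VAE at once, it encodes size-1 slices sequentially via .split(1) and reassembles them with torch.cat. A minimal sketch of that pattern in isolation, with a hypothetical encode_fn standing in for vae._encode:

import torch

def encode_in_slices(encode_fn, batch: torch.Tensor) -> torch.Tensor:
    # Encode one sample at a time so only a single sample's
    # activations are live, then rejoin along the batch dim.
    return torch.cat([encode_fn(chunk) for chunk in batch.split(1)])

# e.g. video_latents = encode_in_slices(vae._encode, videos)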
