
Commit 4f2744e

Update for Windows compatibility (#32)
* Replaced lambda statements with named static methods so they work with pickle in dataset.py, and moved the collate function out of main into its own CollateFunction class in cogvideox_image_to_video_lora.py
* Update cogvideox_image_to_video_lora.py: remove the bug fix related to image encoding, since it is already present in #31
* Update cogvideox_image_to_video_lora.py: revert to last commit (60c4682)
1 parent 3754d20 commit 4f2744e
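
For background on the change: on Windows, PyTorch starts DataLoader workers with the spawn method, which pickles the dataset (including its transforms) and the collate_fn; lambdas and functions defined inside main() cannot be pickled, so they must be replaced with module-level callables. A minimal standalone sketch of the failure and of the picklable-callable pattern (illustrative only, not code from this commit):

import pickle

# Lambdas (and functions defined inside another function) cannot be pickled,
# which is what spawn-based DataLoader workers on Windows require of
# collate_fn and of transforms stored on the dataset.
try:
    pickle.dumps(lambda x: x / 255.0)
except (pickle.PicklingError, AttributeError) as err:
    print("lambda is not picklable:", err)

# A module-level callable class is picklable: pickle stores a reference to
# the class plus the instance attributes, so worker processes can rebuild it.
class ScaleTransform:
    def __call__(self, x):
        return x / 255.0

pickle.dumps(ScaleTransform())  # succeeds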


2 files changed, +35 −20 lines changed

training/cogvideox_image_to_video_lora.py

Lines changed: 25 additions & 18 deletions
@@ -200,6 +200,28 @@ def log_validation(
 
     return videos
 
+class CollateFunction:
+    def __init__(self, weight_dtype, load_tensors):
+        self.weight_dtype = weight_dtype
+        self.load_tensors = load_tensors
+
+    def __call__(self, data):
+        prompts = [x["prompt"] for x in data[0]]
+
+        if self.load_tensors:
+            prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
+
+        images = [x["image"] for x in data[0]]
+        images = torch.stack(images).to(dtype=self.weight_dtype, non_blocking=True)
+
+        videos = [x["video"] for x in data[0]]
+        videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
+
+        return {
+            "images": images,
+            "videos": videos,
+            "prompts": prompts,
+        }
 
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
@@ -486,29 +508,13 @@ def load_model_hook(models, input_dir):
         video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
     )
 
-    def collate_fn(data):
-        prompts = [x["prompt"] for x in data[0]]
-
-        if args.load_tensors:
-            prompts = torch.stack(prompts).to(dtype=weight_dtype, non_blocking=True)
-
-        images = [x["image"] for x in data[0]]
-        images = torch.stack(images).to(dtype=weight_dtype, non_blocking=True)
-
-        videos = [x["video"] for x in data[0]]
-        videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True)
-
-        return {
-            "images": images,
-            "videos": videos,
-            "prompts": prompts,
-        }
+    collate_fn_instance = CollateFunction(weight_dtype, args.load_tensors)
 
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=1,
         sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
-        collate_fn=collate_fn,
+        collate_fn=collate_fn_instance,
         num_workers=args.dataloader_num_workers,
         pin_memory=args.pin_memory,
     )
@@ -641,6 +647,7 @@ def collate_fn(data):
 
                 # Encode videos
                 if not args.load_tensors:
+                    images = images.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
                     image_noise_sigma = torch.normal(
                         mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype
                     )
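
Since the point of this refactor is that the collate callable must survive pickling for spawn-based DataLoader workers, a quick standalone check is possible. The class below is a trimmed copy of the CollateFunction added above, re-declared only so the snippet runs on its own; the dtype and flag values are placeholders:

import pickle
import torch

# Trimmed copy of the CollateFunction added in the hunk above, re-declared so
# this check is self-contained.
class CollateFunction:
    def __init__(self, weight_dtype, load_tensors):
        self.weight_dtype = weight_dtype
        self.load_tensors = load_tensors

    def __call__(self, data):
        prompts = [x["prompt"] for x in data[0]]
        images = torch.stack([x["image"] for x in data[0]]).to(dtype=self.weight_dtype, non_blocking=True)
        videos = torch.stack([x["video"] for x in data[0]]).to(dtype=self.weight_dtype, non_blocking=True)
        return {"images": images, "videos": videos, "prompts": prompts}

# Unlike the old closure over weight_dtype and args, the instance round-trips
# through pickle, which is what Windows (spawn) DataLoader workers need.
collate = CollateFunction(weight_dtype=torch.bfloat16, load_tensors=False)
restored = pickle.loads(pickle.dumps(collate))
assert restored.weight_dtype == torch.bfloat16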

training/dataset.py

Lines changed: 10 additions & 2 deletions
@@ -86,12 +86,20 @@ def __init__(
 
         self.video_transforms = transforms.Compose(
             [
-                transforms.RandomHorizontalFlip(random_flip) if random_flip else transforms.Lambda(lambda x: x),
-                transforms.Lambda(lambda x: x / 255.0),
+                transforms.RandomHorizontalFlip(random_flip) if random_flip else transforms.Lambda(self.identity_transform),
+                transforms.Lambda(self.scale_transform),
                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
             ]
         )
 
+    @staticmethod
+    def identity_transform(x):
+        return x
+
+    @staticmethod
+    def scale_transform(x):
+        return x / 255.0
+
     def __len__(self) -> int:
         return self.num_videos
 