Skip to content

Commit db9b295

Browse files
authored
Windows support for T2V scripts (#48)
1 parent 6c00cf0 commit db9b295

File tree

3 files changed

+46
-30
lines changed

3 files changed

+46
-30
lines changed

training/cogvideox_image_to_video_lora.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -202,11 +202,11 @@ def log_validation(
202202

203203

204204
class CollateFunction:
205-
def __init__(self, weight_dtype, load_tensors):
205+
def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
206206
self.weight_dtype = weight_dtype
207207
self.load_tensors = load_tensors
208208

209-
def __call__(self, data):
209+
def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
210210
prompts = [x["prompt"] for x in data[0]]
211211

212212
if self.load_tensors:
@@ -519,13 +519,13 @@ def load_model_hook(models, input_dir):
519519
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
520520
)
521521

522-
collate_fn_instance = CollateFunction(weight_dtype, args.load_tensors)
522+
collate_fn = CollateFunction(weight_dtype, args.load_tensors)
523523

524524
train_dataloader = DataLoader(
525525
train_dataset,
526526
batch_size=1,
527527
sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
528-
collate_fn=collate_fn_instance,
528+
collate_fn=collate_fn,
529529
num_workers=args.dataloader_num_workers,
530530
pin_memory=args.pin_memory,
531531
)

training/cogvideox_text_to_video_lora.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,26 @@ def log_validation(
198198
return videos
199199

200200

201+
class CollateFunction:
202+
def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
203+
self.weight_dtype = weight_dtype
204+
self.load_tensors = load_tensors
205+
206+
def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
207+
prompts = [x["prompt"] for x in data[0]]
208+
209+
if self.load_tensors:
210+
prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
211+
212+
videos = [x["video"] for x in data[0]]
213+
videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
214+
215+
return {
216+
"videos": videos,
217+
"prompts": prompts,
218+
}
219+
220+
201221
def main(args):
202222
if args.report_to == "wandb" and args.hub_token is not None:
203223
raise ValueError(
@@ -491,19 +511,7 @@ def load_model_hook(models, input_dir):
491511
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
492512
)
493513

494-
def collate_fn(data):
495-
prompts = [x["prompt"] for x in data[0]]
496-
497-
if args.load_tensors:
498-
prompts = torch.stack(prompts).to(dtype=weight_dtype, non_blocking=True)
499-
500-
videos = [x["video"] for x in data[0]]
501-
videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True)
502-
503-
return {
504-
"videos": videos,
505-
"prompts": prompts,
506-
}
514+
collate_fn = CollateFunction(weight_dtype, args.load_tensors)
507515

508516
train_dataloader = DataLoader(
509517
train_dataset,

training/cogvideox_text_to_video_sft.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -188,6 +188,26 @@ def log_validation(
188188
return videos
189189

190190

191+
class CollateFunction:
192+
def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
193+
self.weight_dtype = weight_dtype
194+
self.load_tensors = load_tensors
195+
196+
def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
197+
prompts = [x["prompt"] for x in data[0]]
198+
199+
if self.load_tensors:
200+
prompts = torch.stack(prompts).to(dtype=self.weight_dtype, non_blocking=True)
201+
202+
videos = [x["video"] for x in data[0]]
203+
videos = torch.stack(videos).to(dtype=self.weight_dtype, non_blocking=True)
204+
205+
return {
206+
"videos": videos,
207+
"prompts": prompts,
208+
}
209+
210+
191211
def main(args):
192212
if args.report_to == "wandb" and args.hub_token is not None:
193213
raise ValueError(
@@ -457,19 +477,7 @@ def load_model_hook(models, input_dir):
457477
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
458478
)
459479

460-
def collate_fn(data):
461-
prompts = [x["prompt"] for x in data[0]]
462-
463-
if args.load_tensors:
464-
prompts = torch.stack(prompts).to(dtype=weight_dtype, non_blocking=True)
465-
466-
videos = [x["video"] for x in data[0]]
467-
videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True)
468-
469-
return {
470-
"videos": videos,
471-
"prompts": prompts,
472-
}
480+
collate_fn = CollateFunction(weight_dtype, args.load_tensors)
473481

474482
train_dataloader = DataLoader(
475483
train_dataset,

0 commit comments

Comments
 (0)