Commit feb2e26
Improve dataset preparation support + multiresolution prep (#39)

* update
* make style
* renormalize correctly
* apply suggestions from review
* apply suggestions from review
* update
1 parent a6c246c commit feb2e26

5 files changed: +419 -355 lines

README.md

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ from diffusers import export_to_video
 pipe = CogVideoXPipeline.from_pretrained(
     "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
 ).to("cuda")
-+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name=["cogvideox-lora"])
++ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora")
 + pipe.set_adapters(["cogvideox-lora"], [1.0])
 
 video = pipe("<my-awesome-prompt>").frames[0]
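This hunk fixes the README snippet: in diffusers, `load_lora_weights` takes `adapter_name` as a single string, while `set_adapters` is the call that takes lists of adapter names and weights. A minimal end-to-end sketch with the corrected call (the LoRA repo id and output path are placeholders, as in the README):

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
).to("cuda")

# adapter_name is a single string; set_adapters takes lists of names and weights.
pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora")
pipe.set_adapters(["cogvideox-lora"], [1.0])

video = pipe("<my-awesome-prompt>").frames[0]
export_to_video(video, "output.mp4", fps=8)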
@@ -429,7 +429,7 @@ With `train_batch_size = 4`:
 - [ ] Make scripts compatible with FSDP
 - [x] Make scripts compatible with DeepSpeed
 - [ ] vLLM-powered captioning script
-- [ ] Multi-resolution/frame support in `prepare_dataset.py`
+- [x] Multi-resolution/frame support in `prepare_dataset.py`
 - [ ] Analyzing traces for potential speedups and removing as many syncs as possible
 - [ ] Support for QLoRA (priority), and other types of high usage LoRAs methods
 - [x] Test scripts with memory-efficient optimizer from bitsandbytes

README_zh.md

Lines changed: 1 addition & 1 deletion
@@ -440,7 +440,7 @@ diffusers (this branch adds LoRA loading support for CogVideoX image-to-video
 - [ ] Make scripts compatible with FSDP
 - [x] Make scripts compatible with DeepSpeed
 - [ ] vLLM-based captioning script
-- [ ] Multi-resolution/frame support in `prepare_dataset.py`
+- [x] Multi-resolution/frame support in `prepare_dataset.py`
 - [ ] Analyze performance bottlenecks and reduce sync operations as much as possible
 - [ ] Support QLoRA (priority) and other widely used LoRA methods
 - [x] Test scripts with the memory-efficient optimizer from bitsandbytes

training/cogvideox_image_to_video_lora.py

Lines changed: 3 additions & 1 deletion
@@ -200,6 +200,7 @@ def log_validation(
 
     return videos
 
+
 class CollateFunction:
     def __init__(self, weight_dtype, load_tensors):
         self.weight_dtype = weight_dtype
@@ -223,6 +224,7 @@ def __call__(self, data):
             "prompts": prompts,
         }
 
+
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -647,7 +649,7 @@ def load_model_hook(models, input_dir):
 
         # Encode videos
         if not args.load_tensors:
-            images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
+            images = images.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
             image_noise_sigma = torch.normal(
                 mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype
             )
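For context, the last hunk samples per-sample log-sigmas from N(-3.0, 0.5) for image conditioning augmentation. A sketch of the surrounding step, under the assumption (not shown in this diff) that the log-sigmas are exponentiated and broadcast over the remaining dimensions before noising the conditioning images:

import torch

def add_image_noise(images: torch.Tensor) -> torch.Tensor:
    # images: [B, C, F, H, W], as produced by the permute above.
    # Log-sigmas ~ N(mean=-3.0, std=0.5), matching the hunk; the exp and
    # broadcast below are assumptions about the unshown surrounding code.
    log_sigma = torch.normal(
        mean=-3.0, std=0.5, size=(images.size(0),), device=images.device, dtype=images.dtype
    )
    sigma = torch.exp(log_sigma)[:, None, None, None, None]
    return images + torch.randn_like(images) * sigma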

training/dataset.py

Lines changed: 38 additions & 5 deletions
@@ -78,15 +78,16 @@ def __init__(
             self.video_paths,
         ) = self._load_dataset_from_csv()
 
-        self.num_videos = len(self.video_paths)
-        if self.num_videos != len(self.prompts):
+        if len(self.video_paths) != len(self.prompts):
             raise ValueError(
                 f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
             )
 
         self.video_transforms = transforms.Compose(
             [
-                transforms.RandomHorizontalFlip(random_flip) if random_flip else transforms.Lambda(self.identity_transform),
+                transforms.RandomHorizontalFlip(random_flip)
+                if random_flip
+                else transforms.Lambda(self.identity_transform),
                 transforms.Lambda(self.scale_transform),
                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
             ]
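Taken together with `scale_transform` in the next hunk (which divides by 255), the `Normalize(mean=0.5, std=0.5)` step maps uint8-range frames from [0, 255] into [-1, 1]. A standalone sketch of the equivalent pipeline:

import torch
from torchvision import transforms

# Sketch of the equivalent normalization: [0, 255] -> [0, 1] -> [-1, 1].
video_transforms = transforms.Compose(
    [
        transforms.Lambda(lambda x: x / 255.0),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
)

frames = torch.randint(0, 256, (3, 480, 720)).float()  # [C, H, W], uint8-range values
out = video_transforms(frames)
assert out.min() >= -1.0 and out.max() <= 1.0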
@@ -101,7 +102,7 @@ def scale_transform(x):
         return x / 255.0
 
     def __len__(self) -> int:
-        return self.num_videos
+        return len(self.video_paths)
 
     def __getitem__(self, index: int) -> Dict[str, Any]:
         if isinstance(index, list):
@@ -358,10 +359,30 @@ def _find_nearest_resolution(self, height, width):
 
 
 class BucketSampler(Sampler):
-    def __init__(self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True) -> None:
+    r"""
+    PyTorch Sampler that groups 3D data by height, width and frames.
+
+    Args:
+        data_source (`VideoDataset`):
+            A PyTorch dataset object that is an instance of `VideoDataset`.
+        batch_size (`int`, defaults to `8`):
+            The batch size to use for training.
+        shuffle (`bool`, defaults to `True`):
+            Whether or not to shuffle the data in each batch before dispatching to dataloader.
+        drop_last (`bool`, defaults to `False`):
+            Whether or not to drop incomplete buckets of data after completely iterating over all data
+            in the dataset. If set to True, only batches that have `batch_size` number of entries will
+            be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
+            and batches that do not have `batch_size` number of entries will also be yielded.
+    """
+
+    def __init__(
+        self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
+    ) -> None:
         self.data_source = data_source
         self.batch_size = batch_size
         self.shuffle = shuffle
+        self.drop_last = drop_last
 
         self.buckets = {resolution: [] for resolution in data_source.resolutions}
 
@@ -377,3 +398,15 @@ def __iter__(self):
                 yield self.buckets[(f, h, w)]
                 del self.buckets[(f, h, w)]
                 self.buckets[(f, h, w)] = []
+
+        if self.drop_last:
+            return
+
+        for fhw, bucket in list(self.buckets.items()):
+            if len(bucket) == 0:
+                continue
+            if self.shuffle:
+                random.shuffle(bucket)
+            yield bucket
+            del self.buckets[fhw]
+            self.buckets[fhw] = []
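With `drop_last=False`, leftover partial buckets are flushed after the main pass, so no sample is silently dropped. Note that `BucketSampler` yields fully formed buckets of samples rather than indices; the `if isinstance(index, list)` pass-through in `__getitem__` above exists so a yielded bucket survives the DataLoader indexing path. A self-contained toy sketch of the pattern (names, shapes, and the `batch_size=1` wiring are illustrative, not the repo's actual training setup):

import random
from torch.utils.data import DataLoader, Dataset, Sampler

class ToyVideoDataset(Dataset):
    """Toy stand-in for VideoDataset: items carry a (frames, height, width) shape."""

    def __init__(self):
        self.shapes = [(49, 480, 720), (49, 480, 720), (49, 480, 720), (25, 256, 256)]

    def __len__(self):
        return len(self.shapes)

    def __getitem__(self, index):
        if isinstance(index, list):  # pass a pre-built bucket straight through
            return index
        num_frames, height, width = self.shapes[index]
        return {"num_frames": num_frames, "height": height, "width": width}

class ToyBucketSampler(Sampler):
    """Group samples by (frames, height, width); emit a bucket when it fills up."""

    def __init__(self, data_source, batch_size=2, shuffle=True, drop_last=False):
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        buckets = {}
        for data in self.data_source:
            key = (data["num_frames"], data["height"], data["width"])
            buckets.setdefault(key, []).append(data)
            if len(buckets[key]) == self.batch_size:
                if self.shuffle:
                    random.shuffle(buckets[key])
                yield buckets.pop(key)
        if self.drop_last:
            return
        for bucket in buckets.values():  # flush incomplete buckets
            if bucket:
                yield bucket

dataset = ToyVideoDataset()
sampler = ToyBucketSampler(dataset, batch_size=2)
# batch_size=1 because each sampler "index" is already a whole bucket.
loader = DataLoader(dataset, batch_size=1, sampler=sampler, collate_fn=lambda x: x[0])
for batch in loader:
    print([(d["num_frames"], d["height"], d["width"]) for d in batch])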
