@@ -620,7 +620,7 @@ def _preprocess_data(self):
620620
621621 # Training transforms
622622 frames = (frames - 127.5 ) / 127.5
623- frames = frames .permute (0 , 3 , 1 , 2 ) # [F, C, H, W]
623+ frames = frames .permute (0 , 3 , 1 , 2 ) # [F, C, H, W]
624624 progress_dataset_bar .set_description (
625625 f"Loading progress Resizing video from { frames .shape [2 ]} x{ frames .shape [3 ]} to { self .height } x{ self .width } "
626626 )
@@ -1234,7 +1234,6 @@ def load_model_hook(models, input_dir):
12341234 id_token = args .id_token ,
12351235 )
12361236
1237-
12381237 def encode_video (video , bar ):
12391238 bar .update (1 )
12401239 video = video .to (accelerator .device , dtype = vae .dtype ).unsqueeze (0 )
@@ -1246,7 +1245,9 @@ def encode_video(video, bar):
12461245 range (0 , len (train_dataset .instance_videos )),
12471246 desc = "Loading Encode videos" ,
12481247 )
1249- train_dataset .instance_videos = [encode_video (video ,progress_encode_bar ) for video in train_dataset .instance_videos ]
1248+ train_dataset .instance_videos = [
1249+ encode_video (video , progress_encode_bar ) for video in train_dataset .instance_videos
1250+ ]
12501251 progress_encode_bar .close ()
12511252
12521253 def collate_fn (examples ):
0 commit comments