Skip to content

Commit 606b642

Browse files
committed
[WIP][Fix] order of sigmas, dataloader issue to debug
1 parent 6f2c09d commit 606b642

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

fastvideo/distill/solver.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,16 @@ class EulerSolver:
242242

243243
def __init__(self, sigmas, timesteps=1000, euler_timesteps=50):
244244
self.step_ratio = timesteps // euler_timesteps
245+
245246
self.euler_timesteps = (np.arange(1, euler_timesteps + 1) * self.step_ratio).round().astype(np.int64) - 1
246247
self.euler_timesteps_prev = np.asarray([0] + self.euler_timesteps[:-1].tolist())
247248
self.sigmas = sigmas[self.euler_timesteps]
248249
self.sigmas_prev = np.asarray([sigmas[0]] +
249250
sigmas[self.euler_timesteps[:-1]].tolist()) # either use sigma0 or 0
251+
print(f"sigmas: {sigmas}")
252+
print(f"euler_timesteps: {self.euler_timesteps}")
253+
print(f"sigmas: {self.sigmas}")
254+
print(f"sigmas_prev: {self.sigmas_prev}")
250255

251256
self.euler_timesteps = torch.from_numpy(self.euler_timesteps).long()
252257
self.euler_timesteps_prev = torch.from_numpy(self.euler_timesteps_prev).long()

fastvideo/v1/training/distillation_pipeline.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs):
150150
sigmas = noise_scheduler.sigmas
151151

152152
self.solver = EulerSolver(
153-
sigmas.numpy()[::-1],
153+
sigmas.numpy(),
154154
noise_scheduler.config.num_train_timesteps,
155155
euler_timesteps=fastvideo_args.num_euler_timesteps,
156156
)
@@ -324,7 +324,19 @@ def log_validation(self, transformer, fastvideo_args, global_step):
324324
x = torchvision.utils.make_grid(x, nrow=6)
325325
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
326326
frames.append((x * 255).numpy().astype(np.uint8))
327-
videos.append(frames)
327+
# videos.append(frames)
328+
videos = [frames]
329+
330+
video_filenames = []
331+
video_captions = []
332+
for i, video in enumerate(videos):
333+
caption = captions[i]
334+
filename = os.path.join(
335+
fastvideo_args.output_dir,
336+
f"validation_step_{global_step}_video_{i}.mp4")
337+
imageio.mimsave(filename, video, fps=sampling_param.fps)
338+
video_filenames.append(filename)
339+
video_captions.append(caption)
328340

329341
# Log validation results
330342
if self.rank == 0:

0 commit comments

Comments
 (0)