diff --git a/flashvideo/flow_video.py b/flashvideo/flow_video.py
index 3371107..12e48ab 100644
--- a/flashvideo/flow_video.py
+++ b/flashvideo/flow_video.py
@@ -195,6 +195,66 @@ def single_function_evaluation(self,
         print(f'single_function_evaluation time at {t}', end_time - start_time)
         return vt
 
+    def shared_step(self, batch):
+        x = self.get_input(batch)
+
+        b, t, c, h, w = x.shape
+
+        # fall back to the memory-saving encoder for long, high-resolution clips
+        save_memory_flag = h * w > (580 * 820) and t > 1
+
+        ref_x = batch["ref_mp4"].to(self.dtype)
+
+        # NOTE: `enable_aug` was undefined in the original hunk; reading it from
+        # share_cache is an assumption, matching how the other knobs are passed in.
+        enable_aug = self.share_cache.get("enable_aug", False)
+
+        if enable_aug and random.random() < self.share_cache.get("aug_prob", 0.5):
+            # strong pixel degradation: use the pre-degraded reference from the dataset
+            lr_x = ref_x
+            lr_x = lr_x.permute(0, 2, 1, 3, 4).contiguous()
+        else:
+            # weak pixel degradation: antialiased bilinear down/up-sampling of the clean frames
+            lr_x = x.reshape(b * t, c, h, w)
+            upscale_factor = self.share_cache.get("upscale_factor", 4)
+
+            if "scale_range" in self.share_cache:
+                scale_range = self.share_cache["scale_range"]  # e.g. 4 ~ 8
+                upscale_factor = random.uniform(scale_range[0], scale_range[1])
+
+            lr_x = F.interpolate(lr_x, scale_factor=1 / upscale_factor,
+                                 mode="bilinear", align_corners=False, antialias=True)
+            lr_x = F.interpolate(lr_x, size=(h, w),
+                                 mode="bilinear", align_corners=False, antialias=True)
+            lr_x = lr_x.reshape(b, t, c, h, w)
+            lr_x = lr_x.permute(0, 2, 1, 3, 4).contiguous()
+
+        if save_memory_flag:
+            lr_z = self.save_memory_encode_first_stage(lr_x, batch)
+        else:
+            lr_z = self.encode_first_stage(lr_x, batch)
+
+        # cache the degraded latent as the reference for the loss below
+        lr_z = lr_z.permute(0, 2, 1, 3, 4).contiguous()
+        self.share_cache["ref_x"] = lr_z
+
+        x = x.permute(0, 2, 1, 3, 4).contiguous()
+
+        if save_memory_flag:
+            x = self.save_memory_encode_first_stage(x, batch)
+        else:
+            x = self.encode_first_stage(x, batch)
+
+        x = x.permute(0, 2, 1, 3, 4).contiguous()
+
+        loss, loss_dict = self(x, batch)
+        return loss, loss_dict
+
     @torch.no_grad()
     def sample(
         self,
@@ -274,5 +334,90 @@ def __init__(self,
         self.schedule = None
         super().__init__(**kwargs)
 
-    def __call__(self, network, denoiser, conditioner, input, batch):
-        pass
+    def __call__(self, network, denoiser, conditioner, input, batch):
+        cond = conditioner(batch)
+        additional_model_inputs = {
+            key: batch[key] for key in self.batch2model_keys.intersection(batch)
+        }
+
+        ref_x = self.share_cache["ref_x"]
+        b, num_f, c, h, w = ref_x.shape
+
+        if "ref_noise_step" in self.share_cache or "ref_noise_step_range" in self.share_cache:
+            if "ref_noise_step" in self.share_cache and "ref_noise_step_range" not in self.share_cache:
+                ref_noise_step = self.share_cache["ref_noise_step"]
+            else:
+                # partition the configured range across ranks so each rank samples
+                # from its own slice, e.g. 600 ~ 900 with world_size=2 gives
+                # rank 0 the slice 600 ~ 750
+                ref_noise_step_min, ref_noise_step_max = self.share_cache["ref_noise_step_range"]
+                if num_f == 1 and "img_ref_noise_step_range" in self.share_cache:
+                    ref_noise_step_min, ref_noise_step_max = self.share_cache["img_ref_noise_step_range"]
+                local_rank, world_size = dist.get_rank(), dist.get_world_size()
+                local_rand_range = (ref_noise_step_max - ref_noise_step_min) // world_size
+                local_rand_start = local_rank * local_rand_range + ref_noise_step_min
+                local_rand_end = (local_rank + 1) * local_rand_range + ref_noise_step_min
+                ref_noise_step = torch.randint(local_rand_start, local_rand_end, (1,)).item()
+
+            self.share_cache["sample_ref_noise_step"] = ref_noise_step
+            ref_alphas_cumprod_sqrt = self.sigma_sampler.idx_to_sigma(
+                torch.zeros(input.shape[0]).fill_(ref_noise_step).long().cpu())
+
+            ref_alphas_cumprod_sqrt = ref_alphas_cumprod_sqrt.to(input.device)
+            ref_x = self.share_cache["ref_x"]
+            ref_noise = torch.randn_like(ref_x)
+
+            # noise the degraded reference latent to the sampled step
+            ref_noised_input = ref_x * append_dims(ref_alphas_cumprod_sqrt, ref_x.ndim) \
+                + ref_noise * append_dims(
+                    (1 - ref_alphas_cumprod_sqrt**2) ** 0.5, ref_x.ndim
+                )
+            self.share_cache["ref_x"] = ref_noised_input
+
+        bs = input.shape[0]
+        dev = input.device
+        dtype = input.dtype
+
+        t = torch.rand(bs, device=dev, dtype=dtype)
+
+        if "shift_t" in self.share_cache:
+            shift_t = float(self.share_cache["shift_t"])
+            t = 1 - shift_t * (1 - t) / (1 + (shift_t - 1) * (1 - t))
+
+        num_interval = int(self.share_cache.get("num_noise_interval", 1))
+        # split t into num_interval evenly spaced sub-intervals; the broadcast
+        # below assumes a per-GPU batch size of 1
+        interval = 1.0 / num_interval
+
+        each_start = [i * interval for i in range(num_interval)]
+        each_start = torch.tensor(each_start, device=dev, dtype=dtype)
+        t = each_start + t * interval
+
+        dummy_batch_size = num_interval
+
+        idx = 1000 - (t * 1000)
+        # sample xt and ut
+        x0 = self.share_cache["ref_x"]
+
+        x1 = input
+        x0 = x0.repeat(dummy_batch_size, 1, 1, 1, 1)
+        x1 = x1.repeat(dummy_batch_size, 1, 1, 1, 1)
+
+        t = pad_v_like_x(t, x0)
+        xt = x0 * (1 - t) + t * x1
+
+        additional_model_inputs["idx"] = idx
+        new_cond = dict()
+        for k, v in cond.items():
+            num_dim = v.ndim
+            new_cond[k] = v.repeat(dummy_batch_size, *([1] * (num_dim - 1)))
+        vt = network(xt, t=idx, c=new_cond, **additional_model_inputs)
+
+        return (vt - (x1 - x0)).square().mean()
\ No newline at end of file
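
Note: a standalone sketch of the weak-degradation branch of shared_step, for sanity-checking outside the engine. The helper name, tensor shapes, and default scale range are illustrative; only the two antialiased F.interpolate calls mirror the diff.

    import random
    import torch
    import torch.nn.functional as F

    def degrade_frames(frames, scale_range=(4.0, 8.0)):
        # frames: (b*t, c, h, w); downscale by a random factor, then
        # upscale back to the original size, both with antialiasing
        h, w = frames.shape[-2:]
        factor = random.uniform(*scale_range)
        low = F.interpolate(frames, scale_factor=1 / factor, mode="bilinear",
                            align_corners=False, antialias=True)
        return F.interpolate(low, size=(h, w), mode="bilinear",
                             align_corners=False, antialias=True)

    frames = torch.rand(2, 3, 64, 64)
    print(degrade_frames(frames).shape)  # torch.Size([2, 3, 64, 64])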
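Note: the per-rank noise-step sampling in __call__, pulled out as a pure function. The explicit rank/world_size arguments replace the dist.get_rank()/dist.get_world_size() calls so the sketch runs without a process group; the function name is illustrative.

    import torch

    def sample_ref_noise_step(step_range, rank, world_size):
        # each rank draws from its own slice of the configured range,
        # e.g. (600, 900) with world_size=2 gives rank 0 the slice [600, 750)
        lo, hi = step_range
        span = (hi - lo) // world_size
        start = lo + rank * span
        return torch.randint(start, start + span, (1,)).item()

    print(sample_ref_noise_step((600, 900), rank=0, world_size=2))  # in [600, 750)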
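Note: the shift_t warp applied to the uniform t above, shown on a small grid. It maps [0, 1] onto [0, 1]; for shift_t > 1 it biases samples toward t = 0, the noised-reference end of the interpolation. The function name is illustrative.

    import torch

    def shift_warp(t, shift_t):
        return 1 - shift_t * (1 - t) / (1 + (shift_t - 1) * (1 - t))

    t = torch.linspace(0, 1, 5)
    print(shift_warp(t, 3.0))  # tensor([0.0000, 0.1000, 0.2500, 0.5000, 1.0000])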
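Note: the objective at the end of __call__ in miniature. x_t linearly interpolates the noised reference x0 and the clean latent x1, and the network regresses the constant velocity x1 - x0. The view() stands in for pad_v_like_x; shapes and the random prediction are placeholders.

    import torch

    def interpolate_and_target(x0, x1, t):
        # broadcast t to the latent's shape, then form x_t and the velocity target
        t = t.view(-1, *([1] * (x0.ndim - 1)))
        xt = (1 - t) * x0 + t * x1
        return xt, x1 - x0

    x0 = torch.randn(2, 4, 8, 8)   # noised degraded-reference latent
    x1 = torch.randn(2, 4, 8, 8)   # clean target latent
    xt, target = interpolate_and_target(x0, x1, torch.rand(2))
    vt = torch.randn_like(xt)      # stand-in for the network prediction
    loss = (vt - target).square().mean()
    print(xt.shape, loss)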