Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 147 additions & 2 deletions flashvideo/flow_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,66 @@ def single_function_evaluation(self,
print(f'single_function_evaluation time at {t}', end_time - start_time)
return vt




def shared_step(self, batch):
x = self.get_input(batch)

b, t, c, h, w = x.shape

if h*w > (580 * 820) and t > 1:
save_memory_flag = True
else:
save_memory_flag = False

ref_x = batch["ref_mp4"].to(self.dtype)

if enable_aug and random.random() < self.share_cache.get("aug_prob", 0.5):
# strong pixel deg: ref_x from dataset
lr_x = ref_x
lr_x = lr_x.permute(0, 2, 1, 3, 4).contiguous()

else:
# weak pixel deg
lr_x = x.reshape(b * t, c, h, w)
upscale_factor = self.share_cache.get("upscale_factor", 4)


if "scale_range" in self.share_cache:
scale_range = self.share_cache["scale_range"] # 4 ~ 8
upscale_factor = random.uniform(scale_range[0], scale_range[1])

lr_x = F.interpolate(lr_x, scale_factor=1 / upscale_factor, mode="bilinear", align_corners=False, antialias=True)
lr_x = F.interpolate(lr_x, size= (h, w) , mode="bilinear", align_corners=False, antialias=True)
lr_x = lr_x.reshape(b, t, c, h, w)
lr_x = lr_x.permute(0, 2, 1, 3, 4).contiguous()

if save_memory_flag:
lr_z = self.save_memory_encode_first_stage(lr_x, batch)
else:
lr_z = self.encode_first_stage(lr_x, batch)


lr_z = lr_z.permute(0, 2, 1, 3, 4).contiguous()
self.share_cache["ref_x"] = lr_z

x = x.permute(0, 2, 1, 3, 4).contiguous()

if save_memory_flag:
x = self.save_memory_encode_first_stage(x, batch)
else:
x = self.encode_first_stage(x, batch)



x = x.permute(0, 2, 1, 3, 4).contiguous()

loss, loss_dict = self(x, batch)
return loss, loss_dict



@torch.no_grad()
def sample(
self,
Expand Down Expand Up @@ -274,5 +334,90 @@ def __init__(self,
self.schedule = None
super().__init__(**kwargs)

def __call__(self, network, denoiser, conditioner, input, batch):
pass
def __call__(self, network, denoiser, conditioner, input, batch):
cond = conditioner(batch)
additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}


ref_x = self.share_cache["ref_x"]
b, num_f, c, h, w = ref_x.shape



if "ref_noise_step" in self.share_cache or "ref_noise_step_range" in self.share_cache:
if "ref_noise_step" in self.share_cache and "ref_noise_step_range" not in self.share_cache:
ref_noise_step = self.share_cache["ref_noise_step"]
else:
# ref_noise_step_range 600~900 -> 600 ~ 750
ref_noise_step_min, ref_noise_step_max = self.share_cache["ref_noise_step_range"]
if num_f == 1 and "img_ref_noise_step_range" in self.share_cache:
ref_noise_step_min, ref_noise_step_max = self.share_cache["img_ref_noise_step_range"]
local_rank, world_size = dist.get_rank(), dist.get_world_size()
local_rand_range = (ref_noise_step_max - ref_noise_step_min) // world_size
local_rand_start = local_rank * local_rand_range + ref_noise_step_min
local_rand_end = (local_rank + 1) * local_rand_range + ref_noise_step_min
ref_noise_step = torch.randint(local_rand_start, local_rand_end, (1,)).item()

self.share_cache["sample_ref_noise_step"] = ref_noise_step
ref_alphas_cumprod_sqrt =self.sigma_sampler.idx_to_sigma(
torch.zeros(input.shape[0]).fill_(ref_noise_step).long().cpu())


ref_alphas_cumprod_sqrt = ref_alphas_cumprod_sqrt.to(input.device)
ref_x = self.share_cache["ref_x"]
ref_noise = torch.randn_like(ref_x)

# deg latent
ref_noised_input = ref_x * append_dims(ref_alphas_cumprod_sqrt, ref_x.ndim) \
+ ref_noise * append_dims(
(1 - ref_alphas_cumprod_sqrt**2) ** 0.5, ref_x.ndim
)
self.share_cache["ref_x"] = ref_noised_input


bs = input.shape[0]
dev = input.device
dtype = input.dtype


t = torch.rand(bs, device=dev, dtype=dtype)


if "shift_t" in self.share_cache:
shift_t = float(self.share_cache["shift_t"])
t = 1 - shift_t*(1-t) / (1 + (shift_t -1)*(1-t))

num_interval = int(self.share_cache.get("num_noise_interval", 1))
# # split t to each device
interval = 1.0 / num_interval

each_start = [i * interval for i in range(num_interval)]
each_start = torch.tensor(each_start, device=dev, dtype=dtype)
t = each_start + t * interval

dummy_batch_size = num_interval


idx = 1000 - (t * 1000)
# sample xt and ut
x0 = self.share_cache["ref_x"]



x1 = input
x0 = x0.repeat(dummy_batch_size, 1, 1, 1, 1)
x1 = x1.repeat(dummy_batch_size, 1, 1, 1, 1)

t = pad_v_like_x(t, x0)
xt = x0 * (1 - t) + t * x1


additional_model_inputs["idx"] = idx
new_cond = dict()
for k, v in cond.items():
num_dim = v.ndim
new_cond[k] = v.repeat(dummy_batch_size, *([1] * (num_dim - 1)))
vt = network(xt, t=idx, c= new_cond, **additional_model_inputs)


return (vt - (x1 -x0)).square().mean()