-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathdrop_frame.py
More file actions
126 lines (102 loc) · 3.99 KB
/
drop_frame.py
File metadata and controls
126 lines (102 loc) · 3.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import random
import torch
from decord import VideoReader
import cv2
def get_sorted_random_integers(a, b, n):
    """Return `n` distinct integers drawn uniformly from [a, b), ascending.

    Args:
        a: inclusive lower bound of the sampling range.
        b: exclusive upper bound of the sampling range.
        n: how many distinct integers to draw.

    Returns:
        A sorted list of `n` unique integers from range(a, b).

    Raises:
        ValueError: if a >= b, if n is negative, or if n exceeds b - a.
    """
    if a >= b:
        raise ValueError("Lower bound 'a' must be less than upper bound 'b'")
    if n < 0:
        raise ValueError("'n' must be a non-negative integer")
    if n > (b - a):
        raise ValueError("The number of selected integers 'n' cannot exceed the total count in the range")
    # sample() already guarantees uniqueness; sorted() returns a fresh list.
    return sorted(random.sample(range(a, b), n))
def latent_process(latent, bsz=2):
    """Randomly zero out frames of a latent video tensor (frame-drop augmentation).

    With probability 1 - use_prob (i.e. 0.4) the input is returned untouched.
    Otherwise one strategy is picked by random.choice: random_drop appears
    twice in the candidate list, so it is chosen 2/3 of the time and mid_drop
    1/3 of the time.

    NOTE(review): the strategies modify `latent` in place (`*= 0` on tensor
    slices), so the caller's tensor is mutated even though it is also
    returned — confirm this is intended.

    Args:
        latent: tensor of shape (bsz * frames, channels, height, width) —
            assumed from the indexing below; TODO confirm against callers.
        bsz: number of batch items concatenated along dim 0.

    Returns:
        The (possibly in-place modified) latent tensor.
    """
    use_prob = 0.6          # probability of applying any augmentation at all
    drop_num = 3            # frames zeroed per batch item by random_drop
    drop_patch_num = 7      # max patches per frame for random_patch_zeroing
    patch_size = [3, 7]     # [min, max] patch edge length (inclusive)
    # Skip augmentation entirely with probability 1 - use_prob.
    if random.random() > use_prob:
        return latent
    param_dict = {
        'drop_num': drop_num,
        'drop_patch_num': drop_patch_num,
        'patch_size': patch_size
    }
    def random_drop(this_latent, param_dict):
        # Zero `drop_num` distinct random frames within each batch item's
        # contiguous slice along dim 0.
        drop_num = param_dict['drop_num']
        masked_frame_idxs = []
        single_bz_frames_num = this_latent.shape[0] // bsz
        for b in range(bsz):
            masked_frame_idxs.extend(
                get_sorted_random_integers(single_bz_frames_num * b, single_bz_frames_num * (b+1), drop_num)
            )
        this_latent[masked_frame_idxs] *= 0
        return this_latent
    def mid_drop(this_latent, param_dict):
        # Keep only boundary frames; zero every frame in between.
        masked_frame_idxs = []
        single_bz_frames_num = this_latent.shape[0] // bsz
        # Frames to PRESERVE (despite the variable name): first/last frame of
        # each batch item. NOTE(review): this list is hard-coded for bsz == 2;
        # with any other bsz the interior boundaries are not covered — confirm.
        masked_frame_idxs = [0, single_bz_frames_num-1, single_bz_frames_num, this_latent.shape[0]-1]
        total_idx = [i for i in range(this_latent.shape[0])]
        used_idx = sorted(
            set(total_idx) - set(masked_frame_idxs)
        )
        this_latent[used_idx] *= 0
        return this_latent
    def random_patch_zeroing(this_latent, param_dict):
        # Zero 1..drop_patch_num random rectangular patches in every frame.
        # NOTE(review): defined but never included in the random.choice list
        # below, so this strategy is currently dead code — confirm whether it
        # was meant to be selectable.
        drop_patch_num = param_dict['drop_patch_num']
        patch_size = param_dict['patch_size']
        frames, channels, height, width = this_latent.shape
        for f in range(frames):
            num_patches = random.randint(1, drop_patch_num)
            for _ in range(num_patches):
                patch_height = random.randint(patch_size[0], patch_size[1])
                patch_width = random.randint(patch_size[0], patch_size[1])
                # Clamp the patch origin so the patch fits inside the frame;
                # if the patch is as large as the frame, anchor it at 0.
                if height - patch_height > 0:
                    y = random.randint(0, height - patch_height)
                else:
                    y = 0
                if width - patch_width > 0:
                    x = random.randint(0, width - patch_width)
                else:
                    x = 0
                this_latent[f, :, y:y+patch_height, x:x+patch_width] *= 0
        return this_latent
    # random_drop is listed twice to give it double weight vs mid_drop.
    return random.choice([
        random_drop,
        random_drop,
        mid_drop,
    ])(latent, param_dict)
if __name__ == '__main__':
    import decord
    import torch
    import numpy as np
    from moviepy.editor import VideoClip

    def load_video_as_tensor(video_path):
        """Read every frame of a video file into a (T, C, H, W) tensor."""
        reader = decord.VideoReader(video_path)
        batch = reader.get_batch(range(len(reader))).asnumpy()
        # decord yields (T, H, W, C); move channels in front of H, W.
        return torch.from_numpy(batch).permute(0, 3, 1, 2)

    def tensor_to_video(tensor):
        """Convert a (T, C, H, W) tensor back to a (T, H, W, C) numpy array."""
        return tensor.permute(0, 2, 3, 1).contiguous().numpy()

    def save_video(video_array, save_path, fps=30):
        """Write a (T, H, W, C) frame array to disk via moviepy."""
        def make_frame(t):
            # Clamp to the final frame so the clip's last instant stays valid.
            return video_array[min(int(t * fps), len(video_array) - 1)]
        clip = VideoClip(make_frame, duration=len(video_array) / fps)
        clip.write_videofile(save_path, fps=fps)

    def process_video_with_func(video_path, save_path, func):
        """Load a video, apply `func` to its frame tensor, save the result."""
        frames = load_video_as_tensor(video_path)
        save_video(tensor_to_video(func(frames)), save_path)

    def example_func(tensor):
        # Simple demo transform: halve every pixel value.
        return tensor // 2

    video_path = 'example.mp4'
    save_path = 'output.mp4'
    process_video_with_func(video_path, save_path, latent_process)