Skip to content

Commit 066ea45

Browse files
Refactor video loading function and add time instruction
1 parent 94c893a commit 066ea45

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

playground/demo/video_demo.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,25 +64,31 @@ def parse_args():
6464
parser.add_argument("--api_key", type=str, help="OpenAI API key")
6565
parser.add_argument("--mm_newline_position", type=str, default="no_token")
6666
parser.add_argument("--force_sample", type=lambda x: (str(x).lower() == 'true'), default=False)
67+
parser.add_argument("--add_time_instruction", type=str, default=False)
6768
return parser.parse_args()
6869

6970

70-
def load_video(video_path, args):
71-
vr = VideoReader(video_path, ctx=cpu(0))
72-
total_frame_num = len(vr)
73-
fps = round(vr.get_avg_fps())
74-
frame_idx = [i for i in range(0, len(vr), fps)]
75-
# sample_fps = args.for_get_frames_num if total_frame_num > args.for_get_frames_num else total_frame_num
76-
if len(frame_idx) > args.for_get_frames_num or args.force_sample:
77-
sample_fps = args.for_get_frames_num
78-
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
79-
frame_idx = uniform_sampled_frames.tolist()
80-
spare_frames = vr.get_batch(frame_idx).asnumpy()
81-
# Save frames as images
82-
# for i, frame in enumerate(spare_frames):
83-
# cv2.imwrite(f'{args.output_dir}/frame_{i}.jpg', cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
71+
def load_video(video_path,args):
72+
if max_frames_num == 0:
73+
return np.zeros((1, 336, 336, 3))
74+
vr = VideoReader(video_path, ctx=cpu(0),num_threads=1)
75+
total_frame_num = len(vr)
76+
video_time = total_frame_num / vr.get_avg_fps()
77+
fps = round(vr.get_avg_fps()/fps)
78+
frame_idx = [i for i in range(0, len(vr), fps)]
79+
frame_time = [i/fps for i in frame_idx]
80+
if len(frame_idx) > args.for_get_frames_num or args.force_sample:
81+
sample_fps = max_frames_num
82+
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
83+
frame_idx = uniform_sampled_frames.tolist()
84+
frame_time = [i/vr.get_avg_fps() for i in frame_idx]
85+
frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
86+
spare_frames = vr.get_batch(frame_idx).asnumpy()
87+
# import pdb;pdb.set_trace()
88+
89+
return spare_frames,frame_time,video_time
90+
8491

85-
return spare_frames
8692

8793

8894
def load_video_base64(path):
@@ -177,17 +183,20 @@ def run_inference(args):
177183
# Check if the video exists
178184
if os.path.exists(video_path):
179185
if "gpt4v" != args.model_path:
180-
video = load_video(video_path, args)
186+
video,frame_time,video_time = load_video(video_path, args)
181187
video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].half().cuda()
182188
video = [video]
183189
else:
184-
video = load_video_base64(video_path)
190+
spare_frames,frame_time,video_time = load_video_base64(video_path)
185191
interval = int(len(video) / args.for_get_frames_num)
186192

187193
# try:
188194
# Run inference on the video and add the output to the list
189195
if "gpt4v" != args.model_path:
190196
qs = question
197+
if args.add_time_instruction:
198+
time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {len(video)} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video."
199+
qs = f'{time_instruciton}\n{qs}'
191200
if model.config.mm_use_im_start_end:
192201
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
193202
else:

0 commit comments

Comments
 (0)