@@ -64,25 +64,31 @@ def parse_args():
     parser.add_argument("--api_key", type=str, help="OpenAI API key")
     parser.add_argument("--mm_newline_position", type=str, default="no_token")
     parser.add_argument("--force_sample", type=lambda x: (str(x).lower() == 'true'), default=False)
+    parser.add_argument("--add_time_instruction", type=lambda x: (str(x).lower() == 'true'), default=False)
     return parser.parse_args()


-def load_video(video_path, args):
-    vr = VideoReader(video_path, ctx=cpu(0))
-    total_frame_num = len(vr)
-    fps = round(vr.get_avg_fps())
-    frame_idx = [i for i in range(0, len(vr), fps)]
-    # sample_fps = args.for_get_frames_num if total_frame_num > args.for_get_frames_num else total_frame_num
-    if len(frame_idx) > args.for_get_frames_num or args.force_sample:
-        sample_fps = args.for_get_frames_num
-        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
-        frame_idx = uniform_sampled_frames.tolist()
-    spare_frames = vr.get_batch(frame_idx).asnumpy()
-    # Save frames as images
-    # for i, frame in enumerate(spare_frames):
-    #     cv2.imwrite(f'{args.output_dir}/frame_{i}.jpg', cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+def load_video(video_path, args):
+    if args.for_get_frames_num == 0:
+        return np.zeros((1, 336, 336, 3))
+    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+    total_frame_num = len(vr)
+    video_time = total_frame_num / vr.get_avg_fps()
+    fps = round(vr.get_avg_fps())
+    frame_idx = [i for i in range(0, len(vr), fps)]
+    frame_time = [i / fps for i in frame_idx]
+    if len(frame_idx) > args.for_get_frames_num or args.force_sample:
+        sample_fps = args.for_get_frames_num
+        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
+        frame_idx = uniform_sampled_frames.tolist()
+        frame_time = [i / vr.get_avg_fps() for i in frame_idx]
+    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+    spare_frames = vr.get_batch(frame_idx).asnumpy()
+    # import pdb;pdb.set_trace()
+
+    return spare_frames, frame_time, video_time
+

-    return spare_frames


 def load_video_base64(path):
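
For reference, a minimal sketch of how the reworked `load_video` might be called under this diff. The `Namespace` fields and the video path are illustrative assumptions, not part of the commit:

```python
# Minimal usage sketch, assuming decord and numpy are installed and load_video
# is defined as in the diff above. "sample.mp4" and the argument values are
# hypothetical placeholders.
from argparse import Namespace

args = Namespace(for_get_frames_num=16, force_sample=True)
frames, frame_time, video_time = load_video("sample.mp4", args)

print(frames.shape)   # (16, H, W, 3) uint8 frames from decord
print(video_time)     # total duration in seconds (float)
print(frame_time)     # comma-separated timestamps, e.g. "0.00s,2.13s,..."
```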
@@ -177,17 +183,20 @@ def run_inference(args):
         # Check if the video exists
         if os.path.exists(video_path):
             if "gpt4v" != args.model_path:
-                video = load_video(video_path, args)
+                video, frame_time, video_time = load_video(video_path, args)
                 video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].half().cuda()
                 video = [video]
             else:
-                video = load_video_base64(video_path)
+                spare_frames, frame_time, video_time = load_video_base64(video_path)
                 interval = int(len(video) / args.for_get_frames_num)

         # try:
         # Run inference on the video and add the output to the list
         if "gpt4v" != args.model_path:
             qs = question
+            if args.add_time_instruction:
+                time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
+                qs = f'{time_instruction}\n{qs}'
             if model.config.mm_use_im_start_end:
                 qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
             else:
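
A rough illustration of the prompt string the new `--add_time_instruction` branch builds, using placeholder values in place of what `load_video` returns (the duration, timestamps, frame count, and question below are made up for the example):

```python
# Illustrative stand-ins for load_video's return values and the dataset question.
video_time = 32.00
frame_time = "0.00s,8.00s,16.00s,24.00s"
num_frames = 4
question = "What is happening in the video?"

time_instruction = (
    f"The video lasts for {video_time:.2f} seconds, and {num_frames} frames are "
    f"uniformly sampled from it. These frames are located at {frame_time}. "
    f"Please answer the following questions related to this video."
)
prompt = f"{time_instruction}\n{question}"
print(prompt)
```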