@@ -67,26 +67,25 @@ def parse_args():
6767 parser .add_argument ("--add_time_instruction" , type = str , default = False )
6868 return parser .parse_args ()
6969
70+ def load_video (video_path ,args ):
71+ if args .for_get_frames_num == 0 :
72+ return np .zeros ((1 , 336 , 336 , 3 ))
73+ vr = VideoReader (video_path , ctx = cpu (0 ),num_threads = 1 )
74+ total_frame_num = len (vr )
75+ video_time = total_frame_num / vr .get_avg_fps ()
76+ fps = round (vr .get_avg_fps ())
77+ frame_idx = [i for i in range (0 , len (vr ), fps )]
78+ frame_time = [i / fps for i in frame_idx ]
79+ if len (frame_idx ) > args .for_get_frames_num or args .force_sample :
80+ sample_fps = args .for_get_frames_num
81+ uniform_sampled_frames = np .linspace (0 , total_frame_num - 1 , sample_fps , dtype = int )
82+ frame_idx = uniform_sampled_frames .tolist ()
83+ frame_time = [i / vr .get_avg_fps () for i in frame_idx ]
84+ frame_time = "," .join ([f"{ i :.2f} s" for i in frame_time ])
85+ spare_frames = vr .get_batch (frame_idx ).asnumpy ()
86+ # import pdb;pdb.set_trace()
7087
71- def load_video (video_path ,args ):
72- if max_frames_num == 0 :
73- return np .zeros ((1 , 336 , 336 , 3 ))
74- vr = VideoReader (video_path , ctx = cpu (0 ),num_threads = 1 )
75- total_frame_num = len (vr )
76- video_time = total_frame_num / vr .get_avg_fps ()
77- fps = round (vr .get_avg_fps ()/ fps )
78- frame_idx = [i for i in range (0 , len (vr ), fps )]
79- frame_time = [i / fps for i in frame_idx ]
80- if len (frame_idx ) > args .for_get_frames_num or args .force_sample :
81- sample_fps = max_frames_num
82- uniform_sampled_frames = np .linspace (0 , total_frame_num - 1 , sample_fps , dtype = int )
83- frame_idx = uniform_sampled_frames .tolist ()
84- frame_time = [i / vr .get_avg_fps () for i in frame_idx ]
85- frame_time = "," .join ([f"{ i :.2f} s" for i in frame_time ])
86- spare_frames = vr .get_batch (frame_idx ).asnumpy ()
87- # import pdb;pdb.set_trace()
88-
89- return spare_frames ,frame_time ,video_time
88+ return spare_frames ,frame_time ,video_time
9089
9190
9291
@@ -148,6 +147,17 @@ def run_inference(args):
148147 else :
149148 pass
150149
150+ # import pdb;pdb.set_trace()
151+ if getattr (model .config , "force_sample" , None ) is not None :
152+ args .force_sample = model .config .force_sample
153+ else :
154+ args .force_sample = False
155+
156+ if getattr (model .config , "add_time_instruction" , None ) is not None :
157+ args .add_time_instruction = model .config .add_time_instruction
158+ else :
159+ args .add_time_instruction = False
160+
151161 # Create the output directory if it doesn't exist
152162 if not os .path .exists (args .output_dir ):
153163 os .makedirs (args .output_dir )
@@ -195,7 +205,7 @@ def run_inference(args):
195205 if "gpt4v" != args .model_path :
196206 qs = question
197207 if args .add_time_instruction :
198- time_instruciton = f"The video lasts for { video_time :.2f} seconds, and { len (video )} frames are uniformly sampled from it. These frames are located at { frame_time } .Please answer the following questions related to this video."
208+ time_instruciton = f"The video lasts for { video_time :.2f} seconds, and { len (video [ 0 ] )} frames are uniformly sampled from it. These frames are located at { frame_time } .Please answer the following questions related to this video."
199209 qs = f'{ time_instruciton } \n { qs } '
200210 if model .config .mm_use_im_start_end :
201211 qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n " + qs
0 commit comments