Skip to content

Commit 245e3f4

Browse files
authored
Merge pull request #205 from LLaVA-VL/yhzhang/video_dev
update video inference logic
2 parents c685306 + db3939e commit 245e3f4

File tree

2 files changed: 35 additions and 20 deletions.

llava/model/llava_arch.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -253,6 +253,10 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
253 253
if vision_tower is None or images is None or input_ids.shape[1] == 1:
254 254
return input_ids, position_ids, attention_mask, past_key_values, None, labels
255 255

256+
if isinstance(modalities, str):
257+
modalities = [modalities]
258+
259+
# import pdb; pdb.set_trace()
256 260
if type(images) is list or images.ndim == 5:
257 261
if type(images) is list:
258 262
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
@@ -301,6 +305,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
301 305
# currently image_feature is a tensor of shape (4, num_patches, hidden_size)
302 306
# we want to first unflatten it to (2, 2, h, w, hidden_size)
303 307
# rank0_print("At least we are reaching here")
308+
# import pdb; pdb.set_trace()
304 309
if image_idx in video_idx_in_batch: # video operations
305 310
# rank0_print("Video")
306 311
if mm_newline_position == "grid":

playground/demo/video_demo.py

Lines changed: 30 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -67,26 +67,25 @@ def parse_args():
67 67
parser.add_argument("--add_time_instruction", type=str, default=False)
68 68
return parser.parse_args()
69 69

70+
def load_video(video_path,args):
71+
if args.for_get_frames_num == 0:
72+
return np.zeros((1, 336, 336, 3))
73+
vr = VideoReader(video_path, ctx=cpu(0),num_threads=1)
74+
total_frame_num = len(vr)
75+
video_time = total_frame_num / vr.get_avg_fps()
76+
fps = round(vr.get_avg_fps())
77+
frame_idx = [i for i in range(0, len(vr), fps)]
78+
frame_time = [i/fps for i in frame_idx]
79+
if len(frame_idx) > args.for_get_frames_num or args.force_sample:
80+
sample_fps = args.for_get_frames_num
81+
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
82+
frame_idx = uniform_sampled_frames.tolist()
83+
frame_time = [i/vr.get_avg_fps() for i in frame_idx]
84+
frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
85+
spare_frames = vr.get_batch(frame_idx).asnumpy()
86+
# import pdb;pdb.set_trace()
70 87

71-
def load_video(video_path,args):
72-
if max_frames_num == 0:
73-
return np.zeros((1, 336, 336, 3))
74-
vr = VideoReader(video_path, ctx=cpu(0),num_threads=1)
75-
total_frame_num = len(vr)
76-
video_time = total_frame_num / vr.get_avg_fps()
77-
fps = round(vr.get_avg_fps()/fps)
78-
frame_idx = [i for i in range(0, len(vr), fps)]
79-
frame_time = [i/fps for i in frame_idx]
80-
if len(frame_idx) > args.for_get_frames_num or args.force_sample:
81-
sample_fps = max_frames_num
82-
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
83-
frame_idx = uniform_sampled_frames.tolist()
84-
frame_time = [i/vr.get_avg_fps() for i in frame_idx]
85-
frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
86-
spare_frames = vr.get_batch(frame_idx).asnumpy()
87-
# import pdb;pdb.set_trace()
88-
89-
return spare_frames,frame_time,video_time
88+
return spare_frames,frame_time,video_time
90 89

91 90

92 91

@@ -148,6 +147,17 @@ def run_inference(args):
148 147
else:
149 148
pass
150 149

150+
# import pdb;pdb.set_trace()
151+
if getattr(model.config, "force_sample", None) is not None:
152+
args.force_sample = model.config.force_sample
153+
else:
154+
args.force_sample = False
155+
156+
if getattr(model.config, "add_time_instruction", None) is not None:
157+
args.add_time_instruction = model.config.add_time_instruction
158+
else:
159+
args.add_time_instruction = False
160+
151 161
# Create the output directory if it doesn't exist
152 162
if not os.path.exists(args.output_dir):
153 163
os.makedirs(args.output_dir)
@@ -195,7 +205,7 @@ def run_inference(args):
195 205
if "gpt4v" != args.model_path:
196 206
qs = question
197 207
if args.add_time_instruction:
198-
time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {len(video)} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video."
208+
time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video."
199 209
qs = f'{time_instruciton}\n{qs}'
200 210
if model.config.mm_use_im_start_end:
201 211
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs

0 commit comments

Comments
 (0)