1818import json
1919import logging
2020from llava .utils import rank0_print
21- from action .utils import generate_label_map , MultiChoiceGenerator , match_answer , parse_avion_predictions
21+ from action .utils import generate_label_map , MultiChoiceGenerator , match_answer , parse_avion_predictions , avion_video_loader
2222from action .prediction_analysis import PredictionAnalysis
2323import copy
2424from collections import Counter
@@ -33,125 +33,6 @@ def datetime2sec(str):
3333 hh , mm , ss = str .split (':' )
3434 return int (hh ) * 3600 + int (mm ) * 60 + float (ss )
3535
36-
def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    """Pick one frame index from each of ``num_segments`` equal segments.

    The range ``[start_frame, end_frame]`` is split into ``num_segments``
    equal-width segments and the centre of each segment is taken as the
    base sample position.

    Args:
        start_frame: First frame index of the range.
        end_frame: Last boundary of the range.
        num_segments: Number of segments, i.e. number of indices returned.
        jitter: When true, each centre is shifted by a uniform random offset
            of at most half of ``(end_frame - start_frame - 1) / num_segments``
            in either direction (drawn from ``np.random``).

    Returns:
        A list of ``num_segments`` Python ints (positions truncated by
        ``astype(int)``).
    """
    # Segment boundaries; averaging neighbouring boundaries gives the centres.
    boundaries = np.linspace(start_frame, end_frame, num_segments + 1)
    centers = np.convolve(boundaries, [0.5, 0.5], mode='valid')

    if not jitter:
        return centers.astype(int).tolist()

    # Width available for jitter inside one segment.
    jitter_span = float(end_frame - start_frame - 1) / num_segments
    offsets = (np.random.rand(num_segments) - 0.5) * jitter_span
    return (centers + offsets).astype(int).tolist()
44-
45-
def get_video_reader(videoname, num_threads, fast_rrc, rrc_params, fast_rcc, rcc_params):
    """Open ``videoname`` with ``decord.VideoReader``.

    Args:
        videoname: Path of the video file to open.
        num_threads: Forwarded to the reader as ``num_threads``.
        fast_rrc: When true, the reader is created with ``use_rrc=True``;
            takes precedence over ``fast_rcc``.
        rrc_params: ``(size, (scale_min, scale_max))``; ``size`` is used for
            both ``width`` and ``height``.
        fast_rcc: When true (and ``fast_rrc`` is false), the reader is created
            with ``use_rcc=True``.
        rcc_params: ``(size, )``; ``size`` is used for both ``width`` and
            ``height``.

    Returns:
        The constructed ``decord.VideoReader``.

    Note:
        ``use_rrc`` / ``use_rcc`` / ``scale_min`` / ``scale_max`` are passed
        straight through as keyword arguments; presumably they require a
        decord build that accepts them -- TODO confirm.
    """
    reader_kwargs = {'num_threads': num_threads}
    if fast_rrc:
        scale_range = rrc_params[1]
        reader_kwargs.update(
            width=rrc_params[0], height=rrc_params[0],
            use_rrc=True, scale_min=scale_range[0], scale_max=scale_range[1],
        )
    elif fast_rcc:
        reader_kwargs.update(
            width=rcc_params[0], height=rcc_params[0],
            use_rcc=True,
        )
    return decord.VideoReader(videoname, **reader_kwargs)
65-
66-
def video_loader(root, vid, ext, second, end_second,
                 chunk_len=300, fps=30, clip_length=32,
                 threads=1,
                 fast_rrc=False, rrc_params=(224, (0.5, 1.0)),
                 fast_rcc=False, rcc_params=(224, ),
                 jitter=False):
    """Sample ``clip_length`` frames from ``[second, end_second]`` of a video.

    Two storage layouts are supported:

    * ``chunk_len == -1``: the video is the single file
      ``<root>/<vid>.<ext>``.
    * otherwise: the video is a directory ``<root>/<vid>.<ext>/`` holding
      chunk files named ``<chunk_start_second>.<ext>``, each covering
      ``chunk_len`` seconds.

    Args:
        root: Directory containing the videos.
        vid: Video identifier (file / directory stem).
        ext: File extension, without the dot.
        second: Clip start time in seconds.
        end_second: Clip end time in seconds; clamped to the available video.
        chunk_len: Chunk duration in seconds, or ``-1`` for un-chunked video.
        fps: Frames per second used to convert seconds to frame indices.
        clip_length: Number of frames to sample.
        threads: Decoder thread count passed to ``get_video_reader``.
        fast_rrc, rrc_params, fast_rcc, rcc_params: Forwarded to
            ``get_video_reader``.
        jitter: Forwarded to ``get_frame_ids``.

    Returns:
        * ``chunk_len == -1``: a float32 ``torch.Tensor`` of the frames only.
        * otherwise: ``(frames, time_meta)`` where ``frames`` is a float32
          ``torch.Tensor`` with ``clip_length`` entries along dim 0 and
          ``time_meta`` has keys ``'duration'``, ``'n_frames'`` and
          ``'frame_time'`` (comma-separated ``"<t>s"`` offsets relative to
          the first sampled frame).

        NOTE(review): the two modes return different shapes of result;
        callers that unpack ``frames, time_meta`` must not use
        ``chunk_len=-1``. Left unchanged for backward compatibility.

    Raises:
        AssertionError: If ``fps <= 0``, if sampled indices fall outside the
            video, or if the number of loaded frames is not ``clip_length``.
        FileNotFoundError: In chunked mode, if no chunk file at or before
            ``end_second`` exists.
        IndexError: Re-raised from the decoder after printing the request.
    """
    assert fps > 0, 'fps should be greater than 0'

    if chunk_len == -1:
        vr = get_video_reader(
            osp.join(root, '{}.{}'.format(vid, ext)),
            num_threads=threads,
            fast_rrc=fast_rrc, rrc_params=rrc_params,
            fast_rcc=fast_rcc, rcc_params=rcc_params,
        )
        end_second = min(end_second, len(vr) / fps)

        # calculate frame_ids
        frame_offset = int(np.round(second * fps))
        # Never sample from a window shorter than clip_length frames.
        total_duration = max(int((end_second - second) * fps), clip_length)
        frame_ids = get_frame_ids(
            frame_offset, min(frame_offset + total_duration, len(vr)),
            num_segments=clip_length, jitter=jitter,
        )

        # load frames
        assert max(frame_ids) < len(vr)
        try:
            frames = vr.get_batch(frame_ids).asnumpy()
        except decord.DECORDError as error:
            # Best effort: fall back to repeating the first frame.
            print(error)
            frames = vr.get_batch([0] * len(frame_ids)).asnumpy()

        return torch.from_numpy(frames.astype(np.float32))

    time_meta = {'duration': end_second - second}
    chunk_start = int(second) // chunk_len * chunk_len
    chunk_end = int(end_second) // chunk_len * chunk_len
    video_dir = osp.join(root, '{}.{}'.format(vid, ext))

    # Walk backwards to the last chunk that actually exists on disk. Bounded
    # at 0 so that a missing video raises instead of looping forever.
    while True:
        if chunk_end < 0:
            raise FileNotFoundError(
                'no chunk file found under {} for end_second={}'.format(video_dir, end_second)
            )
        video_filename = osp.join(video_dir, '{}.{}'.format(chunk_end, ext))
        if osp.exists(video_filename):
            break
        chunk_end -= chunk_len

    vr = decord.VideoReader(video_filename)
    # Clamp the clip end to the last frame of the last available chunk.
    end_second = min(end_second, (len(vr) - 1) / fps + chunk_end)
    assert chunk_start <= chunk_end

    # calculate frame_ids (absolute, i.e. relative to the start of the video)
    frame_ids = get_frame_ids(
        int(np.round(second * fps)),
        int(np.round(end_second * fps)),
        num_segments=clip_length, jitter=jitter,
    )

    all_frames = []
    all_frame_ids = []
    rel_frame_ids = []
    # allocate absolute frame-ids into the relative ones
    for chunk in range(chunk_start, chunk_end + chunk_len, chunk_len):
        first_id = int(chunk * fps)
        last_id = int((chunk + chunk_len) * fps)
        chunk_frame_ids = [x for x in frame_ids if first_id <= x < last_id]
        rel_frame_ids = [int(frame_id - chunk * fps) for frame_id in chunk_frame_ids]
        vr = get_video_reader(
            osp.join(video_dir, '{}.{}'.format(chunk, ext)),
            num_threads=threads,
            fast_rrc=fast_rrc, rrc_params=rrc_params,
            fast_rcc=fast_rcc, rcc_params=rcc_params,
        )
        try:
            frames = vr.get_batch(rel_frame_ids).asnumpy()
        except decord.DECORDError as error:
            # Best effort: fall back to repeating the chunk's first frame.
            print(error)
            frames = vr.get_batch([0] * len(rel_frame_ids)).asnumpy()
        except IndexError:
            # Print the request for debugging, then propagate: continuing
            # would reuse stale (or undefined) frames from a previous chunk.
            print(root, vid, ext, second, end_second)
            raise
        all_frames.append(frames)
        # Record only the ids served by this chunk so that the timestamps
        # line up one-to-one with the loaded frames.
        all_frame_ids.extend(chunk_frame_ids)
        if sum(chunk_frames.shape[0] for chunk_frames in all_frames) == clip_length:
            break

    res = torch.from_numpy(np.concatenate(all_frames, axis=0).astype(np.float32))
    time_meta['n_frames'] = res.shape[0]
    assert res.shape[0] == clip_length, "{}, {}, {}, {}, {}, {}, {}".format(
        root, vid, second, end_second, res.shape[0], rel_frame_ids, frame_ids)

    # Timestamps in seconds, relative to the first sampled frame.
    frame_time = np.asarray(all_frame_ids, dtype=np.float64) / fps
    frame_time = frame_time - frame_time[0]
    time_meta['frame_time'] = ",".join([f"{t:.2f}s" for t in frame_time])
    return res, time_meta
153-
154-
15536class VideoCaptionDatasetBase (torch .utils .data .Dataset ):
15637 def __init__ (self , dataset , root , metadata , is_trimmed = True ):
15738 self .dataset = dataset
@@ -216,7 +97,7 @@ def get_raw_item(
21697 vid_path , start_second , end_second , fps , narration , verb , noun = self .samples [i ]
21798 # chunk length is the chunked video clip length
21899 # clip length is number of frames we want to sample from the clip
219- frames , time_meta = video_loader (self .root , vid_path , 'MP4' ,
100+ frames , time_meta = avion_video_loader (self .root , vid_path , 'MP4' ,
220101 start_second , end_second ,
221102 chunk_len = chunk_len , fps = fps ,
222103 clip_length = clip_length ,
0 commit comments