@@ -30,14 +30,16 @@ def sort_correspondance(vid_to_intervals, vid_to_gt_narration):
     return sorted_vid_to_gt_narration
 
 
-def get_annotated_intervals(file_path):
+def get_annotated_intervals(file_path, action_representation):
     csv_reader = csv.reader(open(file_path))
     _ = next(csv_reader)
     vid_to_intervals = defaultdict(list)
-    vid_to_gt_narration = defaultdict(list)
+    vid_to_action_representation = defaultdict(list)
     vid_to_action_ids = defaultdict(list)
 
-    labels, mapping_vn2narration, mapping_vn2act, verb_maps, noun_maps = generate_label_map(Path(file_path).parent, 'GT_random_narration')
+    labels, mapping_vn2narration, mapping_vn2act, verb_maps, noun_maps = generate_label_map(Path(file_path).parent, action_representation)
+    print(verb_maps)
+    print(noun_maps)
     for row in csv_reader:
         pid, vid = row[1:3]
         narration = row[8]
@@ -53,15 +55,17 @@ def get_annotated_intervals(file_path):
             #print(f"{vid} has a long duration of action {narration} {end_timestamp - start_timestamp:.2f}")
 
         vid_to_intervals[vid].append((start_timestamp, end_timestamp))
-        vid_to_gt_narration[vid].append(narration)
-
+        if action_representation == 'GT_random_narration':
+            vid_to_action_representation[vid].append(narration)
+        elif action_representation == 'official_key':
+            vid_to_action_representation[vid].append(f'{verb_maps[str(verb_id)]} {noun_maps[str(noun_id)]}')
 
         vid_to_action_ids[vid].append((verb_id, noun_id, action_id))
 
-    return vid_to_intervals, vid_to_gt_narration, vid_to_action_ids
+    return vid_to_intervals, vid_to_action_representation, vid_to_action_ids
 
 
-def build_uid_pad_dict(ann_file,
+def build_uid_pad_dict(ann_file,
                        delta = 3):
     """
     every uid corresponds to two neighboring actions
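A minimal side sketch of the two action_representation branches introduced in this hunk, using made-up verb_maps / noun_maps entries (the real maps come from generate_label_map); the helper name to_representation is hypothetical and not part of the diff:

# Sketch only: toy maps stand in for the real generate_label_map output.
verb_maps = {'0': 'take', '1': 'put'}
noun_maps = {'0': 'plate', '1': 'cupboard'}

def to_representation(action_representation, narration, verb_id, noun_id):
    if action_representation == 'GT_random_narration':
        # keep the annotated narration string as-is
        return narration
    elif action_representation == 'official_key':
        # join the official verb and noun names into a key, e.g. 'put cupboard'
        return f'{verb_maps[str(verb_id)]} {noun_maps[str(noun_id)]}'
    raise ValueError(f'unknown action_representation: {action_representation}')

print(to_representation('official_key', 'put the plate away', 1, 0))  # -> 'put plate'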
@@ -103,10 +107,8 @@ def build_uid_pad_dict(ann_file,
     return uid_to_neighbors
 
 
-def get_pseudo_dict(pseudo_folder, delta = 3):
-    import glob
-
-
+def get_pseudo_dict(pseudo_folder):
+    import glob
     files = glob.glob(os.path.join(pseudo_folder, 'prediction*.json'))
 
     pseudo_data = {}
@@ -124,9 +126,9 @@ def get_pseudo_dict(pseudo_folder, delta = 3):
     assert len(ret) == len(pseudo_data)
     return ret
 
-def get_lookup_dict(ann_file, test_type = 'base', delta = 3, pseudo_folder = None):
+def get_lookup_dict(ann_file, action_representation, test_type = 'base', delta = 3, pseudo_folder = None):
 
-    vid_to_intervals, vid_to_gt_narration, _ = get_annotated_intervals(ann_file)
+    vid_to_intervals, vid_to_action_representation, _ = get_annotated_intervals(ann_file, action_representation)
     table = {}
 
     pseudo_dict = None
@@ -139,7 +141,7 @@ def get_lookup_dict(ann_file, test_type = 'base', delta = 3, pseudo_folder = Non
         sorted_indices = sorted(range(len(intervals)), key = lambda i: intervals[i][1])
 
         sorted_intervals = [intervals[i] for i in sorted_indices]
-        sorted_narrations = [vid_to_gt_narration[vid][i] for i in sorted_indices]
+        sorted_narrations = [vid_to_action_representation[vid][i] for i in sorted_indices]
 
         end_times = [end for _, end in sorted_intervals]
         start_times = [start for start, _ in sorted_intervals]
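For context on the renaming above, a minimal standalone sketch of how get_lookup_dict keeps intervals and their representation strings aligned after sorting by end time; the interval and narration values below are hypothetical:

# Sketch only: illustrates the sort-and-gather step shown in this hunk.
intervals = [(4.0, 9.5), (0.0, 3.2), (10.1, 12.0)]            # hypothetical (start, end) pairs for one video
representations = ['cut onion', 'open fridge', 'wash knife']  # hypothetical, same order as intervals

sorted_indices = sorted(range(len(intervals)), key=lambda i: intervals[i][1])
sorted_intervals = [intervals[i] for i in sorted_indices]
sorted_narrations = [representations[i] for i in sorted_indices]

end_times = [end for _, end in sorted_intervals]
start_times = [start for start, _ in sorted_intervals]
print(sorted_narrations)  # ['open fridge', 'cut onion', 'wash knife']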
@@ -182,99 +184,6 @@ def get_lookup_dict(ann_file, test_type = 'base', delta = 3, pseudo_folder = Non
     return table
 
 
-def sample_uid_triples(anno_file,
-                       delta = 3,
-                       question_type = "triple_direct_answer"):
-    vid_to_intervals, vid_to_gt_narration, vid_to_action_ids = get_annotated_intervals(anno_file)
-    ret = []
-
-    for vid, intervals in vid_to_intervals.items():
-        # Sort intervals by end time
-        sorted_intervals = sorted(intervals, key = lambda x: x[1])
-
-        end_times = [end for _, end in sorted_intervals]
-        start_times = [start for start, _ in sorted_intervals]
-
-        # Look for consecutive triples
-        for i in range(len(sorted_intervals) - 2):  # -2 because we need 3 consecutive intervals
-            id = vid.split('_')[0] + '-' + vid
-
-            # Get time differences between consecutive intervals
-            time_diff1 = start_times[i + 1] - end_times[i]
-            time_diff2 = start_times[i + 2] - end_times[i + 1]
-
-            # Check if both time differences are less than 3 seconds
-            if time_diff1 <= delta and time_diff2 <= delta:
-                # Create UIDs for each interval in the triple
-                uid1 = f"{id}_{round(start_times[i],2)}_{round(end_times[i],2)}"
-                uid2 = f"{id}_{round(start_times[i + 1],2)}_{round(end_times[i + 1],2)}"
-                uid3 = f"{id}_{round(start_times[i + 2],2)}_{round(end_times[i + 2],2)}"
-
-                # Get corresponding narrations
-                verb_id1, noun_id1, action_id1 = vid_to_action_ids[vid][i]
-                verb_id2, noun_id2, action_id2 = vid_to_action_ids[vid][i + 1]
-                verb_id3, noun_id3, action_id3 = vid_to_action_ids[vid][i + 2]
-
-
-                narration1 = vid_to_gt_narration[vid][i]
-                narration2 = vid_to_gt_narration[vid][i + 1]
-                narration3 = vid_to_gt_narration[vid][i + 2]
-
-                if question_type == "triple_multiple_choice":
-                    pass
-                elif question_type == "triple_direct_answer":
-                    target = narration1 + ', ' + narration2 + ', ' + narration3
-
-                triple = {
-                    'id': id,
-                    'video': id,
-                    'start_timestamp': start_times[i],
-                    'end_timestamp': end_times[i + 2],
-                    'gt_narration_triple': [narration1, narration2, narration3],
-                    'conversations': [{"from": "human", "value": ""},
-                                      {"from": "gpt", "value": target}
-                                      ],
-                    "question_type": question_type,
-                    'split': 'train',
-                    'dataset_name': 'EK100',
-                    'triple_meta':
-                        [
-                            {'uid': uid1,
-                             'narration': narration1,
-                             'start_timestep': start_times[i],
-                             'end_timestep': end_times[i],
-                             'duration': round(end_times[i] - start_times[i], 2),
-                             'verb_id': verb_id1,
-                             'noun_id': noun_id1,
-                             'action_id': action_id1
-
-                             },
-                            {'uid': uid2,
-                             'narration': narration2,
-                             'start_timestep': start_times[i + 1],
-                             'end_timestep': end_times[i + 1],
-                             'duration': round(end_times[i + 1] - start_times[i + 1], 2),
-                             'verb_id': verb_id2,
-                             'noun_id': noun_id2,
-                             'action_id': action_id2
-                             },
-                            {'uid': uid3,
-                             'narration': narration3,
-                             'start_timestep': start_times[i + 2],
-                             'end_timestep': end_times[i + 2],
-                             'duration': round(end_times[i + 2] - start_times[i + 2], 2),
-                             'verb_id': verb_id3,
-                             'noun_id': noun_id3,
-                             'action_id': action_id3
-                             }
-                        ]
-                }
-                ret.append(triple)
-
-    print(f'Found {len(ret)} triples with gaps <= {delta} seconds')
-    return ret
-
-
 
 def create_merged_intervals(train_ann_file):
     """
@@ -299,7 +208,11 @@ def create_merged_captions(triple_file, caption_file):
     # with open('ek100_triples.jsonl', 'w') as f:
     #     for item in res:
     #         f.write(json.dumps(item) + '\n')
-    triple_file_path = 'ek100_triples.jsonl'
-    caption_file_path = '/data/shaokai/first_person_annos/train_anno_gpt-gt-reason_4_first_person_all_action_idx.jsonl'
-    create_merged_captions(triple_file_path, caption_file_path)
-
+    # triple_file_path = 'ek100_triples.jsonl'
+    # caption_file_path = '/data/shaokai/first_person_annos/train_anno_gpt-gt-reason_4_first_person_all_action_idx.jsonl'
+    # create_merged_captions(triple_file_path, caption_file_path)
+    ann_file = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv'
+    action_representation = 'GT_random_narration'
+    ret = get_lookup_dict(ann_file, action_representation)
+
+    print(list(ret.items())[:10])