Skip to content

Commit a8d3c83

Browse files
author
Ye Shaokai
committed
triple supports action representation
1 parent e318fc8 commit a8d3c83

File tree

1 file changed

+24
-111
lines changed

1 file changed

+24
-111
lines changed

llava/action/generate_interval_pred.py

Lines changed: 24 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,16 @@ def sort_correspondance(vid_to_intervals, vid_to_gt_narration):
3030
return sorted_vid_to_gt_narration
3131

3232

33-
def get_annotated_intervals(file_path):
33+
def get_annotated_intervals(file_path, action_representation):
3434
csv_reader = csv.reader(open(file_path))
3535
_ = next(csv_reader)
3636
vid_to_intervals = defaultdict(list)
37-
vid_to_gt_narration = defaultdict(list)
37+
vid_to_action_representation = defaultdict(list)
3838
vid_to_action_ids = defaultdict(list)
3939

40-
labels, mapping_vn2narration, mapping_vn2act, verb_maps, noun_maps = generate_label_map(Path(file_path).parent, 'GT_random_narration')
40+
labels, mapping_vn2narration, mapping_vn2act, verb_maps, noun_maps = generate_label_map(Path(file_path).parent, action_representation)
41+
print (verb_maps)
42+
print (noun_maps)
4143
for row in csv_reader:
4244
pid, vid = row[1:3]
4345
narration = row[8]
@@ -53,15 +55,17 @@ def get_annotated_intervals(file_path):
5355
#print(f"{vid} has a long duration of action {narration} {end_timestamp - start_timestamp:.2f}")
5456

5557
vid_to_intervals[vid].append((start_timestamp, end_timestamp))
56-
vid_to_gt_narration[vid].append(narration)
57-
58+
if action_representation == 'GT_random_narration':
59+
vid_to_action_representation[vid].append(narration)
60+
elif action_representation == 'official_key':
61+
vid_to_action_representation[vid].append(f'{verb_maps[str(verb_id)]} {noun_maps[str(noun_id)]}')
5862

5963
vid_to_action_ids[vid].append((verb_id, noun_id, action_id))
6064

61-
return vid_to_intervals, vid_to_gt_narration, vid_to_action_ids
65+
return vid_to_intervals, vid_to_action_representation, vid_to_action_ids
6266

6367

64-
def build_uid_pad_dict(ann_file,
68+
def build_uid_pad_dict(ann_file,
6569
delta = 3):
6670
"""
6771
every uid corresponds to two neighboring actions
@@ -103,10 +107,8 @@ def build_uid_pad_dict(ann_file,
103107
return uid_to_neighbors
104108

105109

106-
def get_pseudo_dict(pseudo_folder, delta = 3):
107-
import glob
108-
109-
110+
def get_pseudo_dict(pseudo_folder):
111+
import glob
110112
files = glob.glob(os.path.join(pseudo_folder, 'prediction*.json'))
111113

112114
pseudo_data = {}
@@ -124,9 +126,9 @@ def get_pseudo_dict(pseudo_folder, delta = 3):
124126
assert len(ret) == len(pseudo_data)
125127
return ret
126128

127-
def get_lookup_dict(ann_file, test_type = 'base', delta = 3, pseudo_folder = None):
129+
def get_lookup_dict(ann_file, action_representation, test_type = 'base', delta = 3, pseudo_folder = None):
128130

129-
vid_to_intervals, vid_to_gt_narration, _ = get_annotated_intervals(ann_file)
131+
vid_to_intervals, vid_to_action_representation, _ = get_annotated_intervals(ann_file, action_representation)
130132
table = {}
131133

132134
pseudo_dict = None
@@ -139,7 +141,7 @@ def get_lookup_dict(ann_file, test_type = 'base', delta = 3, pseudo_folder = Non
139141
sorted_indices = sorted(range(len(intervals)), key=lambda i: intervals[i][1])
140142

141143
sorted_intervals = [intervals[i] for i in sorted_indices]
142-
sorted_narrations = [vid_to_gt_narration[vid][i] for i in sorted_indices]
144+
sorted_narrations = [vid_to_action_representation[vid][i] for i in sorted_indices]
143145

144146
end_times = [end for _, end in sorted_intervals]
145147
start_times = [start for start, _ in sorted_intervals]
@@ -182,99 +184,6 @@ def get_lookup_dict(ann_file, test_type = 'base', delta = 3, pseudo_folder = Non
182184
return table
183185

184186

185-
def sample_uid_triples(anno_file,
186-
delta = 3,
187-
question_type="triple_direct_answer"):
188-
vid_to_intervals, vid_to_gt_narration, vid_to_action_ids = get_annotated_intervals(anno_file)
189-
ret = []
190-
191-
for vid, intervals in vid_to_intervals.items():
192-
# Sort intervals by end time
193-
sorted_intervals = sorted(intervals, key=lambda x: x[1])
194-
195-
end_times = [end for _, end in sorted_intervals]
196-
start_times = [start for start, _ in sorted_intervals]
197-
198-
# Look for consecutive triples
199-
for i in range(len(sorted_intervals)-2): # -2 because we need 3 consecutive intervals
200-
id = vid.split('_')[0] + '-' + vid
201-
202-
# Get time differences between consecutive intervals
203-
time_diff1 = start_times[i+1] - end_times[i]
204-
time_diff2 = start_times[i+2] - end_times[i+1]
205-
206-
# Check if both time differences are less than 3 seconds
207-
if time_diff1 <= delta and time_diff2 <= delta:
208-
# Create UIDs for each interval in the triple
209-
uid1 = f"{id}_{round(start_times[i],2)}_{round(end_times[i],2)}"
210-
uid2 = f"{id}_{round(start_times[i+1],2)}_{round(end_times[i+1],2)}"
211-
uid3 = f"{id}_{round(start_times[i+2],2)}_{round(end_times[i+2],2)}"
212-
213-
# Get corresponding narrations
214-
verb_id1, noun_id1, action_id1 = vid_to_action_ids[vid][i]
215-
verb_id2, noun_id2, action_id2 = vid_to_action_ids[vid][i+1]
216-
verb_id3, noun_id3, action_id3 = vid_to_action_ids[vid][i+2]
217-
218-
219-
narration1 = vid_to_gt_narration[vid][i]
220-
narration2 = vid_to_gt_narration[vid][i+1]
221-
narration3 = vid_to_gt_narration[vid][i+2]
222-
223-
if question_type == "triple_multiple_choice":
224-
pass
225-
elif question_type == "triple_direct_answer":
226-
target = narration1 + ', ' + narration2 + ', ' + narration3
227-
228-
triple = {
229-
'id': id,
230-
'video': id,
231-
'start_timestamp': start_times[i],
232-
'end_timestamp': end_times[i+2],
233-
'gt_narration_triple': [narration1, narration2, narration3],
234-
'conversations': [{"from": "human", "value":""},
235-
{"from": "gpt", "value": target}
236-
],
237-
"question_type": question_type,
238-
'split': 'train',
239-
'dataset_name': 'EK100',
240-
'triple_meta':
241-
[
242-
{ 'uid': uid1,
243-
'narration': narration1,
244-
'start_timestep': start_times[i],
245-
'end_timestep': end_times[i],
246-
'duration': round(end_times[i] - start_times[i],2),
247-
'verb_id': verb_id1,
248-
'noun_id': noun_id1,
249-
'action_id': action_id1
250-
251-
},
252-
{ 'uid': uid2,
253-
'narration': narration2,
254-
'start_timestep': start_times[i+1],
255-
'end_timestep': end_times[i+1],
256-
'duration': round(end_times[i+1] - start_times[i+1],2),
257-
'verb_id': verb_id2,
258-
'noun_id': noun_id2,
259-
'action_id': action_id2
260-
},
261-
{ 'uid': uid3,
262-
'narration': narration3,
263-
'start_timestep': start_times[i+2],
264-
'end_timestep': end_times[i+2],
265-
'duration': round(end_times[i+2] - start_times[i+2],2),
266-
'verb_id': verb_id3,
267-
'noun_id': noun_id3,
268-
'action_id': action_id3
269-
}
270-
]
271-
}
272-
ret.append(triple)
273-
274-
print(f'Found {len(ret)} triples with gaps <= {delta} seconds')
275-
return ret
276-
277-
278187

279188
def create_merged_intervals(train_ann_file):
280189
"""
@@ -299,7 +208,11 @@ def create_merged_captions(triple_file, caption_file):
299208
# with open('ek100_triples.jsonl', 'w') as f:
300209
# for item in res:
301210
# f.write(json.dumps(item) + '\n')
302-
triple_file_path = 'ek100_triples.jsonl'
303-
caption_file_path = '/data/shaokai/first_person_annos/train_anno_gpt-gt-reason_4_first_person_all_action_idx.jsonl'
304-
create_merged_captions(triple_file_path, caption_file_path)
305-
211+
# triple_file_path = 'ek100_triples.jsonl'
212+
# caption_file_path = '/data/shaokai/first_person_annos/train_anno_gpt-gt-reason_4_first_person_all_action_idx.jsonl'
213+
# create_merged_captions(triple_file_path, caption_file_path)
214+
ann_file = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_train.csv'
215+
action_representation = 'GT_random_narration'
216+
ret = get_lookup_dict(ann_file, action_representation)
217+
218+
print(list(ret.items())[:10])

0 commit comments

Comments (0)