Skip to content

Commit 5cfa9ec

Browse files
author
Ye Shaokai
committed
updates
1 parent b84d560 commit 5cfa9ec

File tree

2 files changed

+74
-31
lines changed

2 files changed

+74
-31
lines changed

llava/action/make_visualizations.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def save_visualization(vis_folder, frames, uid):
162162
video_out.write(bgr_frame)
163163
video_out.release()
164164

165-
def visualize_with_uid(uid):
165+
def visualize_with_uid(uid, out_folder):
166166
from llava.action.utils import avion_video_loader
167167
val_metadata = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv'
168168
vid_path = '_'.join(uid.split('_')[:2]).replace('-', '/')
@@ -183,8 +183,8 @@ def visualize_with_uid(uid):
183183
fast_rcc = False,
184184
jitter = False)
185185

186-
vis_folder = f"figure1_vis"
187-
save_visualization(vis_folder, frames, uid)
186+
187+
save_visualization(out_folder, frames, uid)
188188

189189
def visualize_with_llava(pretrained_path, uid, question_type, gen_type):
190190
"""
@@ -251,6 +251,7 @@ def visualize_with_llava(pretrained_path, uid, question_type, gen_type):
251251
# llava_pretrained_path = 'experiments/LLaVA-Video-7B-Qwen2'
252252
# uid = 'P01-P01_11_182.65_192.07'
253253
# visualize_with_llava(llava_pretrained_path, uid, 'caption', 'tim')
254-
visualize_with_uid("P28-P28_16_73.84_74.66")
255-
visualize_with_uid("P28-P28_15_50.66_51.69")
256-
visualize_with_uid("P26-P26_41_113.0_114.1")
254+
# visualize_with_uid("P28-P28_16_73.84_74.66")
255+
# visualize_with_uid("P28-P28_15_50.66_51.69")
256+
# visualize_with_uid("P26-P26_41_113.0_114.1")
257+
visualize_with_uid("P28-P28_26_45.97_46.97", "key_confusing_examples")

llava/action/vis_utils.py

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def get_uid_options_map_from_prediction_folder(uid, prediction_folder):
9999
options = value["avion_preds"]
100100
ret[uid] = options
101101
return ret
102-
def get_uid_official_map(uid, ann_file):
102+
def get_uid_official_map(ann_file):
103103
csv_reader = csv.reader(open(ann_file, 'r'))
104104
_ = next(csv_reader)
105105
anno_root = Path(ann_file).parent
@@ -124,6 +124,32 @@ def get_uid_official_map(uid, ann_file):
124124
ret[uid] = official_key
125125
return ret
126126

127+
def get_uid_narration_map(ann_file):
    """Build a mapping from clip uid to its narration text.

    Each data row of the annotation CSV contributes one entry. The uid is
    formatted as ``{pid}-{vid}_{start}_{end}``, where ``pid`` and ``vid`` are
    columns 1 and 2, and ``start``/``end`` are columns 4 and 5 converted with
    ``datetime2sec`` and rounded to two decimals. The narration is column 8.

    Args:
        ann_file: Path to the annotation CSV; its first row is treated as a
            header and skipped.

    Returns:
        dict mapping uid (str) to narration (str). Empty if the file has no
        data rows. If two rows produce the same uid, the later row wins.
    """
    ret = {}
    # The context manager guarantees the handle is closed; newline='' is the
    # csv-module recommended way to open files handed to csv.reader.
    with open(ann_file, 'r', newline='') as f:
        csv_reader = csv.reader(f)
        # Skip the header row; the default avoids StopIteration on an empty file.
        next(csv_reader, None)
        for row in csv_reader:
            pid, vid = row[1:3]
            start_second = round(float(datetime2sec(row[4])), 2)
            end_second = round(float(datetime2sec(row[5])), 2)
            # assumes pid/vid contain no '/' (e.g. 'P01', 'P01_11') -- TODO confirm
            uid = f'{pid}-{vid}_{start_second}_{end_second}'
            ret[uid] = row[8]
    return ret
151+
152+
127153

128154
def get_narration_by_uid(uid, ann_file):
129155
csv_reader = csv.reader(open(ann_file, 'r'))
@@ -177,7 +203,17 @@ def get_uid_official_map(uid, ann_file):
177203
official_key = verb_maps[str(verb)] + ' ' + noun_maps[str(noun)]
178204
ret[uid] = official_key
179205
return ret
206+
207+
208+
209+
210+
def compare_caption_generation(chatgpt_file, llava_file):
    """Placeholder: not implemented yet, always returns None.

    Both arguments are currently ignored.
    Open question from the author: do we have a llava file for this yet?
    """
    return None
180213

214+
def compare_open_ended_question_answering(chatgpt_file, llava_file):
    """Placeholder: not implemented yet, always returns None.

    Both arguments are currently ignored.
    Open question from the author: do we have a llava file for this yet?
    """
    return None
181217

182218
def search_llavaction_win(tim_chatgpt_file,
183219
random_chatgpt_file,
@@ -202,18 +238,7 @@ def search_llavaction_win(tim_chatgpt_file,
202238
if llavaction_pred[uid]['pred'] == llavaction_pred[uid]['gt'] and \
203239
tim_chatgpt_pred[uid]['pred'] != tim_chatgpt_pred[uid]['gt'] and \
204240
llava_pred[uid]['pred'] != llava_pred[uid]['gt']:
205-
206-
# print ('uid', uid)
207-
# print ('gt', tim_chatgpt_pred[uid]['gt'])
208-
# print ('tim_chatgpt_pred', tim_chatgpt_pred[uid]['pred'])
209-
# print ('llava_pred', llava_pred[uid]['pred'])
210-
# print ('llavaction_pred', llavaction_pred[uid]['pred'])
211-
# print ('options', tim_chatgpt_options)
212-
# print ('llava_options', llava_options)
213-
# print ('llavaction_options', llavaction_options)
214-
# print ('random_chatgpt_options', random_chatgpt_options)
215-
# print ('----')
216-
# get all these printed items in the results dictionary
241+
217242
results[uid] = {'gt': tim_chatgpt_pred[uid]['gt'],
218243
'tim_chatgpt_pred': tim_chatgpt_pred[uid]['pred'],
219244
'llava_pred': llava_pred[uid]['pred'],
@@ -226,33 +251,50 @@ def search_llavaction_win(tim_chatgpt_file,
226251
with open('llavaction_win.json', 'w') as f:
227252
json.dump(results, f, indent = 4)
228253

229-
def get_wrong_prediction_uids(prediction_folder):
254+
def get_wrong_prediction_uids(prediction_folder, ann_file):
    """Collect the clips where llava picked a wrong option although the
    ground truth was among the offered options.

    All ``*.json`` files in ``prediction_folder`` are merged into one dict and
    visited in ascending integer order of their keys. An entry is reported when
    its ``gt_name`` is contained in ``avion_preds`` and ``llava_pred`` differs
    from ``gt_name``. Reported entries are printed and also written to
    ``llava_gets_confused_by_key.json`` in the current working directory.

    Args:
        prediction_folder: Directory holding the prediction JSON files.
        ann_file: Annotation CSV used to look up the narration for each uid.

    Returns:
        dict mapping uid to the reported fields (``llava_pred``, ``gt``,
        ``narration``, ``official_key``, ``options``).
    """
    uid_narration_map = get_uid_narration_map(ann_file)

    data = {}
    for file in glob.glob(os.path.join(prediction_folder, '*.json')):
        with open(file, 'r') as f:
            data.update(json.load(f))

    results = {}
    for key in sorted(data, key=int):
        value = data[key]
        gt_name = value['gt_name']
        options = value['avion_preds']
        llava_pred = value['llava_pred']

        # Only cases where the correct answer was actually offered are of
        # interest, and only when llava chose something else.
        if gt_name not in options or llava_pred == gt_name:
            continue

        path_parts = value['vid_path'].split('/')
        uid = (f'{path_parts[0]}-{path_parts[1]}'
               f"_{value['start_second']}_{value['end_second']}")
        # Look up the narration only for reported entries and tolerate a uid
        # missing from the annotation map, so one mismatch (e.g. differing
        # float formatting) does not abort the whole scan.
        narration = uid_narration_map.get(uid)

        print('uid', uid)
        print('llava_pred', llava_pred)
        print('official key gt', gt_name)
        print('gt narration', narration)
        print('options', options)

        # 'gt' and 'official_key' carry the same value; both keys are kept so
        # the output file layout stays unchanged.
        results[uid] = {'llava_pred': llava_pred,
                        'gt': gt_name,
                        'narration': narration,
                        'official_key': gt_name,
                        'options': options}

    # Written once, after the scan, so the file holds the complete result set.
    with open('llava_gets_confused_by_key.json', 'w') as f:
        json.dump(results, f, indent=4)
    return results
256298

257299
def walk_through(ann_file):
258300
csv_reader = csv.reader(open(ann_file, 'r'))
@@ -283,20 +325,20 @@ def walk_through(ann_file):
283325
print ('narration', narration)
284326
print ('----')
285327
count+=1
286-
print ('count', count)
328+
287329

288330
if __name__ == '__main__':
289331
ann_file = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv'
290-
prediction_folder = '/data/shaokai/LLaVA-NeXT/tt_dev_7b_16f_top20_full_includes_tim/'
332+
prediction_folder = '/data/shaokai/predictions_for_vis/dev_7b_16f_top5_full_includes_tim/'
291333
#walk_through(ann_file)
292-
#get_wrong_prediction_uids(prediction_folder)
334+
get_wrong_prediction_uids(prediction_folder, ann_file)
293335
root = '/data/shaokai/predictions_for_vis/'
294336
chatgpt_tim_file = os.path.join(root, 'gpt-4o-2024-08-06_tim_GT_random_narration_top5_8f_9668samples.json')
295337
chatgpt_random_file = os.path.join(root, 'gpt-4o-2024-08-06_random_GT_random_narration_top5_8f_9668samples.json')
296338
llava_zeroshot_folder = os.path.join(root, 'LLaVA_Video_7B')
297339
llavaction_folder = os.path.join(root, 'LLaVAction_7B')
298-
search_llavaction_win(chatgpt_tim_file,
299-
chatgpt_random_file,
300-
llava_zeroshot_folder,
301-
llavaction_folder)
340+
# search_llavaction_win(chatgpt_tim_file,
341+
# chatgpt_random_file,
342+
# llava_zeroshot_folder,
343+
# llavaction_folder)
302344

0 commit comments

Comments
 (0)