Skip to content

Commit 936977e

Browse files
committed
some updates
1 parent db2ad43 commit 936977e

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

action/chatgpt_utils.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
import cv2
1212
from pathlib import Path
13+
from tqdm import tqdm
1314
from action.prediction_analysis import PredictionAnalysis
1415

1516
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
@@ -186,12 +187,7 @@ def multi_process_run(self):
186187
def parse_item(self, item):
187188

188189
gt_name = item['gt_name']
189-
avion_predictions = item['avion_preds']['predictions']
190-
# _avion_predictions = [e.replace(':', ' ', 1) for e in avion_predictions]
191-
# if gt_name not in _avion_predictions:
192-
# print ('gt_name not in avion_predictions')
193-
# else:
194-
# print ('gt_name in avion_predictions')
190+
avion_predictions = item['avion_preds']['predictions']
195191

196192
vid_path = item['vid_path'][0]
197193
start_second = item['start_second']
@@ -215,7 +211,7 @@ def run(self, indices):
215211
data_batch = {i : self.data[i] for i in range(len(self.data)) if i in indices}
216212
ret = {}
217213

218-
for k,v in data_batch.items():
214+
for k,v in tqdm(data_batch.items()):
219215
parsed_item = self.parse_item(v)
220216
start_timestamp = parsed_item['start_second']
221217
end_timestamp = parsed_item['end_second']
@@ -378,7 +374,7 @@ def run(self, indices):
378374
data_batch = [self.data[i] for i in range(len(self.data)) if i in indices]
379375

380376
ret = {}
381-
for index in indices:
377+
for index in tqdm(indices):
382378
item = self.data[index]
383379
start_timestamp = item['start_timestamp']
384380
end_timestamp = item['end_timestamp']
@@ -444,36 +440,54 @@ def annotate(self, images, data_item):
444440
return response.choices[0].message.parsed
445441

446442

447-
def multi_process_annotate(train_file_path, root):
443+
def multi_process_annotate(train_file_path, root, debug = False):
448444
annotator = GPTAugmentationAnnotator(train_file_path,
449445
root,
450446
clip_length = 4,
451-
debug = True)
447+
debug = debug)
452448
results = annotator.multi_process_run()
453449

454-
def explore_wrong_examples(root, prediction_save_folder):
450+
def explore_wrong_examples(root, prediction_save_folder, debug = False):
455451
annotator = GPTInferenceAnnotator(root,
456452
prediction_save_folder,
457453
clip_length = 4,
458-
debug = True)
454+
debug = debug)
459455
annotator.explore_wrong_examples()
460456

461-
def multi_process_inference(root, prediction_save_folder):
457+
def multi_process_inference(root, prediction_save_folder, debug = False):
462458

463459
annotator = GPTInferenceAnnotator(root,
464460
prediction_save_folder,
465-
clip_length = 4,
466-
debug = True)
461+
clip_length = 32,
462+
debug = debug)
463+
467464
annotator.multi_process_run()
468465

466+
def calculate_gpt_accuracy(path):
467+
with open(path, 'r') as f:
468+
data = json.load(f)
469+
470+
keys = list(data.keys())
471+
print ('length of the data', len(keys))
472+
473+
correct_count = 0
474+
for k,v in data.items():
475+
gt_name = v['gt_name']
476+
chatgpt_answer = v['chatgpt_answer']
477+
if gt_name == chatgpt_answer:
478+
correct_count += 1
479+
else:
480+
print (chatgpt_answer, gt_name)
481+
print ('accuracy', correct_count / len(keys))
482+
469483
if __name__ == '__main__':
470484
#train_file_path = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl'
471485
#root = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100'
472486
train_file_path = '/data/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl'
473487
root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
474488
pred_folder = '/data/epic_kitchen/llavavideo_avion_mc_top10_5epoch_preds'
475489

476-
multi_process_annotate(train_file_path, root)
490+
#multi_process_annotate(train_file_path, root)
477491
#explore_wrong_examples(root, pred_folder)
478-
479-
#multi_process_inference(root, pred_folder)
492+
#multi_process_inference(root, pred_folder, debug = False)
493+
calculate_gpt_accuracy('valset_chatgpt_inference_results/gpt-4o-avion_top10_4frames.json')

0 commit comments

Comments
 (0)