Skip to content

Commit 4b31a16

Browse files
committed
WIP
1 parent 44ec08f commit 4b31a16

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

llava/action/benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
n_frames = 4
1111
topk = 5
1212
action_representation = 'GT_random_narration'
13-
gpt_model = 'gpt-4o-mini-2024-07-18'
14-
#gpt_model = 'gpt-4o-2024-08-06'
13+
#gpt_model = 'gpt-4o-mini-2024-07-18'
14+
gpt_model = 'gpt-4o-2024-08-06'
1515
perspective = 'first_person'
1616
benchmark_testing = True
1717

llava/action/chatgpt_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ def __init__(self,
379379
if gen_type == 'avion' or gen_type == 'tim':
380380
self.mc_generator = ActionMultiChoiceGenerator(self.annotation_root)
381381
assert os.path.exists(self.prediction_file)
382-
print ('prediction_file'*5, self.prediction_file)
383382
with open(self.prediction_file, 'r') as f:
384383
self.action_model_predictions = json.load(f)
385384
else:

llava/action/make_visualizations.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,26 @@ def visualize_with_gpt_with_tim(n_samples, question_type = 'mc_'):
6565
inferencer.multi_process_run(n_samples, disable_api_calling=False)
6666

6767

68+
def visualize_with_gpt_with_avion(n_samples, question_type = 'mc_'):
69+
"""
70+
Here we should test gpt-4o, gpt-4o-mini with different prompts
71+
"""
72+
inferencer = GPTInferenceAnnotator(gpt_model,
73+
root,
74+
annotation_file,
75+
gen_type = 'avion',
76+
prediction_file = avion_prediction_file,
77+
clip_length = n_frames,
78+
question_type = question_type,
79+
action_representation=action_representation,
80+
perspective = perspective,
81+
benchmark_testing = benchmark_testing,
82+
do_visualization = True,
83+
topk = topk)
84+
85+
inferencer.multi_process_run(n_samples, disable_api_calling=False)
86+
6887
if __name__ == '__main__':
6988
#visualize_with_random(1, question_type = "mc_")
70-
visualize_with_gpt_with_tim(1, question_type = "mc_")
89+
#visualize_with_gpt_with_tim(1, question_type = "mc_")
90+
visualize_with_gpt_with_avion(1, question_type = "mc_")

0 commit comments

Comments
 (0)