Skip to content

Commit 92053a1

Browse files
author
Ye Shaokai
committed
added script to calculate action model acc
1 parent 2f03692 commit 92053a1

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

.vscode/launch.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@
294294
"--val_metadata", "/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv",
295295
"--llava_num_frames", "16",
296296
"--clip_length", "16",
297-
"--action_representation", "topk_narration_cut_key",
298-
"--topk_predictions", "5",
297+
"--action_representation", "official_key",
298+
"--topk_predictions", "10",
299299
"--eval_steps", "1",
300300
"--vision_supervision", "all_newlines",
301301
"--action_types", "97,300,3806",
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import json
2+
3+
tim_action_path = '/data/shaokai/TIM_PREDS/tim_pred_ids_val.json'
4+
avion_action_path = '/data/shaokai/AVION_PREDS/avion_pred_ids_val.json'
5+
6+
def calc_acc(action_path):
7+
8+
with open(action_path, 'r') as f:
9+
data = json.load(f)
10+
11+
# calculate top-1, top-5, top-10, top-20 accuracy
12+
13+
top1 = 0
14+
top5 = 0
15+
top10 = 0
16+
top20 = 0
17+
top40 = 0
18+
19+
for i in range(len(data)):
20+
preds = data[str(i)]['predictions']
21+
target = data[str(i)]['target']
22+
23+
if target in preds[:1]:
24+
top1 += 1
25+
if target in preds[:5]:
26+
top5 += 1
27+
if target in preds[:10]:
28+
top10 += 1
29+
30+
if len(preds) >= 20:
31+
if target in preds[:20]:
32+
top20 += 1
33+
34+
if len(preds) >= 40:
35+
if target in preds[:40]:
36+
top40 += 1
37+
38+
print('Top-1 accuracy:', top1/len(data))
39+
print('Top-5 accuracy:', top5/len(data))
40+
print('Top-10 accuracy:', top10/len(data))
41+
if len(preds) >= 20:
42+
print('Top-20 accuracy:', top20/len(data))
43+
if len(preds) >= 40:
44+
print('Top-40 accuracy:', top40/len(data))
45+
46+
print ('tim accs:')
47+
calc_acc(tim_action_path)
48+
49+
print ('avion accs:')
50+
calc_acc(avion_action_path)

0 commit comments

Comments
 (0)