Skip to content

Commit 6c7e92c

Browse files
author
Lu Teng
authored
[Example] Fix GPT-J accuracy check. (#356)
1 parent b112fbc commit 6c7e92c

File tree

2 files changed

+58
-37
lines changed

2 files changed

+58
-37
lines changed

example/gptj/README.md

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,9 @@ pip install jax==0.4.25 jaxlib==0.4.25 flax==0.8.2 transformers==4.38 datasets==
2020
| *```--greedy```*| *```False```*| Enable greedy search or beam search. |
2121
| *```--num-iter```*| *```10```*| Number of iterations. |
2222
| *```--num-warmup```*| *```3```*| Number of warmup iterations. |
23-
| *```--accuracy-only```*| *```False```*| Run for performance or accuracy only. |
23+
| *```--accuracy```*| *```False```*| Run accuracy check. |
2424

25-
## Accuracy Example
26-
27-
```bash
28-
export ZE_AFFINITY_MASK=0
29-
python jax_gptj.py --accuracy-only --dtype "float16"
30-
```
31-
32-
## Performance Example
25+
## Example
3326

3427
To fully utilize the hardware capabilities and achieve the best performance, you may consider setting the below ENV variables to enable our customized optimization strategies.
3528

@@ -57,3 +50,17 @@ python jax_gptj.py --input-tokens 1024 --max-new-tokens 128
5750
Inference latency: x.xxx sec.
5851
Inference throughput: x.xxx samples/sec.
5952
```
53+
54+
### Accuracy Output
55+
56+
```bash
57+
export ZE_AFFINITY_MASK=0
58+
python jax_gptj.py --input-tokens 1024 --max-new-tokens 128 --accuracy
59+
```
60+
61+
```
62+
Inference latency: x.xxx sec.
63+
Inference throughput: x.xxx samples/sec.
64+
65+
Accuracy = 1.00
66+
```

example/gptj/jax_gptj.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
parser.add_argument("--num-iter", default=10, type=int, help="num iter")
4848
parser.add_argument("--num-warmup", default=3, type=int, help="num warmup")
4949
parser.add_argument("--batch-size", default=1, type=int, help="batch size")
50-
parser.add_argument("--accuracy-only", action="store_true")
50+
parser.add_argument("--accuracy", action="store_true")
5151
args = parser.parse_args()
5252
print(args)
5353

@@ -64,7 +64,7 @@
6464
print("---- Model loading done", flush=True)
6565

6666
# input_ids
67-
if not args.accuracy_only and args.prompt is not None:
67+
if not args.accuracy and args.prompt is not None:
6868
prompt = args.prompt
6969
else:
7070
try:
@@ -91,42 +91,51 @@ def run_model(input_ids):
9191
gen_tokens = model.generate(input_ids, max_new_tokens=max_new_tokens, **generate_kwargs)
9292
return gen_tokens
9393

94-
if not args.accuracy_only:
95-
total_time = 0.0
96-
num_iter = args.num_iter
97-
num_warmup = args.num_warmup
98-
for i in range(num_iter):
99-
tic = time.time()
100-
gen_tokens = run_model(input_ids)
101-
gen_text = tokenizer.batch_decode(gen_tokens[0], skip_special_tokens=False)
102-
toc = time.time()
103-
print(gen_text, flush=True)
104-
print("Inference latency: %.3f sec." % (toc - tic), flush=True)
105-
dur = toc - tic
106-
if i >= num_warmup:
107-
total_time += dur
108-
print("\n", "-" * 10, "Summary:", "-" * 10, flush=True)
109-
latency = total_time / (num_iter - num_warmup)
110-
print("Inference latency: %.3f sec." % (latency), flush=True)
111-
print("Inference throughput: %.3f samples/sec.\n" % (args.batch_size / latency), flush=True)
112-
else:
113-
accuracy_check_combination = [[1, '32', 32], [4, '32', 32], [1, '1024', 128], [4, '1024', 128]]
114-
if [num_beams, input_tokens, max_new_tokens] not in accuracy_check_combination:
94+
total_time = 0.0
95+
num_iter = args.num_iter
96+
num_warmup = args.num_warmup
97+
for i in range(num_iter):
98+
tic = time.time()
99+
gen_tokens = run_model(input_ids)
100+
gen_text = tokenizer.batch_decode(gen_tokens[0], skip_special_tokens=False)
101+
toc = time.time()
102+
print(gen_text, flush=True)
103+
print("Inference latency: %.3f sec." % (toc - tic), flush=True)
104+
dur = toc - tic
105+
if i >= num_warmup:
106+
total_time += dur
107+
print("\n", "-" * 10, "Summary:", "-" * 10, flush=True)
108+
latency = total_time / (num_iter - num_warmup)
109+
print("Inference latency: %.3f sec." % (latency), flush=True)
110+
print("Inference throughput: %.3f samples/sec.\n" % (args.batch_size / latency), flush=True)
111+
112+
if args.accuracy:
113+
invalid_flag = False
114+
accuracy_check_combination = [['32', 32], ['32', 1], ['1024', 128], ['1024', 1]]
115+
116+
if num_beams != 1 and num_beams != 4:
117+
invalid_flag = True
118+
119+
if [input_tokens, max_new_tokens] not in accuracy_check_combination:
120+
invalid_flag = True
121+
122+
if invalid_flag:
115123
print((num_beams, input_tokens, max_new_tokens), "is not a supported combination " \
116124
"[num_beams, input_tokens, max_new_tokens] for checking accuracy")
117125
exit(-1)
126+
118127
gen_tokens = run_model(input_ids)
119128
outputs = np.array(gen_tokens[0])[0]
120129
except_outputs = []
121-
if [num_beams, input_tokens, max_new_tokens] == [1, '32', 32]:
130+
if [num_beams, input_tokens] == [1, '32']:
122131
# "She wanted to be a princess, and live in a castle. She wanted to be a mermaid, and live in the sea. She wanted to be a"
123132
except_outputs = [ 7454, 2402, 257, 640, 11, 612, 11196, 257, 2576, 11, 508,
124133
8288, 284, 423, 17545, 13, 1375, 2227, 284, 284, 4113, 290,
125134
1826, 649, 661, 11, 290, 423, 1257, 13, 2227, 284, 307,
126135
257, 21752, 11, 290, 2107, 287, 257, 16669, 1375, 2227, 284,
127136
307, 257, 4017, 23151, 11, 290, 2107, 287, 5417, 13, 1375,
128137
2227, 284, 307, 257,]
129-
elif [num_beams, input_tokens, max_new_tokens] == [4, '32', 32]:
138+
elif [num_beams, input_tokens] == [4, '32']:
130139
# "One day, she decided to go on an adventure. She packed her bags, and set off on her journey.\n\nThe little girl walked and walked,"
131140
except_outputs = [ 7454, 2402, 257, 640, 11, 612, 11196, 257, 1310,
132141
2576, 11, 508, 8288, 284, 423, 17545, 13, 1375,
@@ -135,11 +144,16 @@ def run_model(input_ids):
135144
3066, 284, 467, 319, 281, 8855, 13, 1375, 11856,
136145
607, 11668, 11, 290, 900, 572, 319, 607, 7002,
137146
13, 198, 198, 464, 1310, 2576, 6807, 290, 6807, 11]
138-
elif [num_beams, input_tokens, max_new_tokens] == [1, '1024', 128]:
147+
elif [num_beams, input_tokens] == [1, '1024']:
139148
# "Over time, the concept evolved into a game where you are the pasta, and you are trying to evolve to become the most powerful pasta. You can evolve by eating other pastas, or by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas"
140149
except_outputs = [1026, 318, 1760, 11, 290, 8948, 13, 921, 460, 711, 705, 34652, 2473, 286, 262, 309, 459, 6386, 6, 319, 5565, 11, 290, 319, 262, 3992, 13, 23911, 319, 262, 3992, 2499, 11, 475, 345, 423, 284, 29308, 3294, 3638, 329, 3084, 3867, 290, 326, 460, 307, 257, 1643, 15337, 13, 1318, 318, 257, 1256, 314, 1549, 588, 284, 1561, 546, 13, 314, 481, 467, 832, 790, 7243, 11, 916, 276, 286, 1642, 262, 7226, 644, 1816, 826, 14, 36460, 1351, 13, 26097, 14594, 625, 262, 7505, 373, 2192, 530, 286, 262, 17612, 8861, 543, 314, 550, 284, 1986, 13, 19486, 11, 314, 550, 281, 2126, 286, 644, 1611, 286, 983, 314, 2227, 284, 1205, 11, 11327, 10787, 532, 1223, 351, 257, 1256, 286, 5775, 14, 529, 669, 11, 2829, 9382, 11, 3863, 900, 287, 2272, 11, 6856, 422, 257, 1353, 12, 2902, 1570, 13, 314, 373, 6563, 326, 314, 714, 4197, 597, 7505, 1088, 340, 13, 554, 262, 886, 11, 262, 1917, 351, 257, 7505, 588, 705, 15200, 2122, 6, 287, 257, 983, 318, 326, 6954, 318, 555, 42191, 13, 632, 4325, 832, 1811, 9775, 4738, 23005, 625, 640, 11, 351, 262, 749, 15409, 9943, 7094, 16997, 13, 770, 8513, 1097, 35375, 318, 11, 287, 616, 4459, 11, 257, 1049, 1672, 286, 4036, 6954, 286, 257, 4693, 6476, 257, 4427, 13, 887, 318, 340, 257, 983, 30, 554, 257, 983, 11, 345, 761, 284, 1630, 1223, 284, 3151, 281, 9432, 13, 1320, 1630, 2925, 1028, 644, 6954, 318, 4385, 284, 307, 588, 13, 1002, 345, 1249, 262, 2836, 284, 2298, 703, 284, 18101, 1223, 11, 340, 338, 407, 6954, 7471, 532, 340, 338, 262, 7548, 286, 12661, 1486, 11, 262, 277, 540, 15646, 416, 6282, 1023, 284, 5249, 262, 2126, 286, 6954, 13, 11204, 556, 43758, 290, 257, 11303, 1878, 3699, 11, 326, 338, 407, 1223, 326, 31862, 502, 262, 826, 835, 13, 16227, 11, 616, 4094, 288, 8270, 2611, 618, 14615, 644, 284, 2251, 373, 407, 351, 644, 314, 2227, 284, 2251, 11, 475, 351, 644, 314, 750, 407, 13, 314, 1422, 470, 765, 284, 2251, 281, 705, 600, 32940, 1486, 6, 35375, 290, 31238, 869, 340, 6954, 13, 770, 318, 257, 1917, 11, 286, 1781, 11, 790, 584, 44047, 635, 550, 284, 1986, 13, 843, 22989, 416, 262, 12784, 8948, 11, 407, 867, 5257, 284, 670, 1088, 340, 13, 314, 1549, 910, 262, 691, 1103, 4610, 373, 832, 262, 779, 286, 11666, 6356, 11, 7599, 13, 1406, 1290, 11, 314, 4398, 470, 1775, 597, 5726, 1262, 428, 379, 663, 4755, 11327, 13, 45315, 11, 428, 318, 655, 257, 1257, 5449, 290, 706, 257, 981, 314, 3066, 407, 284, 307, 355, 7646, 351, 262, 983, 2126, 11, 290, 3142, 3589, 284, 2298, 4232, 314, 1807, 561, 670, 503, 13, 2011, 4238, 2126, 373, 284, 2251, 1223, 810, 9265, 3088, 284, 18101, 284, 257, 1306, 1241, 11, 475, 550, 617, 1611, 286, 22156, 2111, 284, 2245, 606, 422, 1804, 523, 13, 314, 1611, 286, 550, 428, 2939, 286, 1692, 15625, 7348, 287, 2272, 3371, 257, 937, 21446, 393, 257, 2272, 5156, 357, 439, 1912, 287, 5878, 25, 317, 4687, 28032, 286, 1781, 8, 475, 314, 3521, 470, 892, 286, 13206, 357, 961, 25, 2726, 8, 12933, 329, 326, 13, 29004, 82, 547, 616, 1306, 12141, 11, 355, 511, 2187, 14078, 4197, 2495, 880, 656, 262, 6954, 7505, 13, 887, 703, 284, 787, 340, 670, 30, 4231, 345, 262, 275, 2398, 11, 393, 4330, 262, 29004, 30, 383, 2368, 290, 2457, 2126, 1625, 284, 502, 832, 616, 11077, 11, 508, 7599, 2921, 502, 262, 2126, 286, 1642, 1223, 546, 262, 6954, 286, 11303, 64, 13, 383, 517, 314, 1807, 546, 340, 262, 517, 340, 14846, 588, 340, 561, 670, 11, 523, 314, 3066, 284, 467, 351, 340, 13, 32200, 602, 351, 616, 20886, 763, 12, 28816, 371, 516, 20342, 357, 8727, 635, 2727, 262, 705, 28452, 36684, 4698, 22242, 6, 9877, 11112, 329, 616, 493, 4951, 8, 2252, 48491, 262, 3721, 11, 355, 340, 2950, 656, 262, 2126, 286, 1719, 1981, 5207, 286, 26296, 7348, 1088, 290, 2111, 284, 18101, 1566, 484, 2627, 477, 12, 44548, 13, 317, 9233, 2126, 994, 373, 326, 262, 983, 561, 670, 284, 4727, 703, 262, 19903, 1338, 35812, 12635, 1625, 284, 2152, 532, 416, 21568, 422, 257, 3487, 8073, 3084, 13, 1406, 262, 2126, 12572, 517, 393, 1342, 656, 428, 25, 345, 389, 5586, 257, 3084, 13, 921, 423, 534, 898, 7480, 11, 351, 318, 534, 705, 8692, 4458, 1318, 389, 642, 584, 10650, 379, 262, 3084, 11, 1123, 351, 511, 898, 7480, 13, 3406, 7480, 460, 10922, 1310, 5207, 286, 26296, 13, 921, 466, 523, 416, 705, 34555, 6, 606, 832, 257, 6859, 13, 2773, 1613, 292, 389, 1365, 621, 1854, 26, 617, 389, 5443, 11, 617, 389, 7387, 13, 1119, 423, 15874, 705, 15805, 82, 3256, 543, 389, 1915, 863, 422, 534, 10824, 357, 5832, 923, 351, 257, 1271, 286, 10824, 737, 4874, 29013, 11, 534, 1613, 292, 923, 7348, 1088, 13, 5334, 13311, 318, 284, 6129, 284, 584, 13375, 11, 287, 1502, 284, 23875, 606, 357, 1169, 9432, 286, 262, 983, 318, 1719, 534, 26296, 23875, 477, 262, 13375, 319, 262, 3084, 737, 887, 484, 389, 1107, 18284, 11, 523, 706, 852, 29013, 11, 345, 423, 645, 1630, 625, 534, 26296, 357, 14925, 22875, 32, 393, 6706, 43, 49100, 737, 3406, 26296, 1595, 470, 588, 584, 661, 338, 26296, 11, 523, 611, 484, 1826, 11, 484, 2686, 10746, 379, 1123, 584, 1566, 530, 10564, 13, 921, 651, 10824, 329, 584, 1613, 292, 534, 898, 26296, 1494, 13, 4874, 257, 26296, 318, 287, 262, 25980, 286, 257, 7480, 11, 340, 4940, 49977, 340, 329, 663, 1074, 13, 632, 2753, 1088, 838, 4201, 329, 257, 7480, 284, 307, 29346, 26, 1342, 611, 517, 26296, 422, 262, 976, 1074, 389, 1088, 13, 1002, 26296, 422, 584, 1074, 389, 1088, 11, 996, 11, 484, 651, 8970, 866, 287, 511, 2230, 11, 5906, 284, 23875, 262, 7480, 11, 1566, 530, 286, 606, 4656, 357, 14925, 30193, 338, 3210, 705, 3103, 6138, 6, 4235, 737, 921, 651, 2173, 790, 1218, 329, 790, 7480, 345, 898, 13, 3827, 640, 11, 262, 3721, 12572, 656, 257, 983, 810, 345, 389, 262, 26296, 11, 290, 345, 389, 2111, 284, 18101, 284, 1716, 262, 749, 3665, 26296, 13, 921, 460, 18101, 416, 6600, 584, 1613, 292, 11, 393, 416, 5170, 584, 1613, 292, 13, 921, 460, 635, 18101, 416, 5170, 584, 1613, 292, 13, 921, 460, 635, 18101, 416, 5170, 584, 1613, 292, 13, 921, 460, 635, 18101, 416, 5170, 584, 1613, 292, 13, 921, 460, 635, 18101, 416, 5170, 584, 1613, 292, 13, 921, 460, 635, 18101, 416, 5170, 584, 1613, 292, 13, 921, 460, 635, 18101, 416, 5170, 584, 1613, 292, 13, 921, 460, 635, 18101, 416, 5170, 584, 1613, 292, 13, 921, 460, 635, 18101, 416, 5170, 584, 1613, 292, 13, 921, 460, 635, 18101, 416, 5170, 584, 1613, 292]
141150
else:
142151
# "Over time, the concept evolved into this: you are sitting a table. You have your own plate, with is your 'base'. There are 5 other guests at the table, each with their own plate. Your plate can spawn little pieces of pasta. You do so by 'ordering' them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying 'costs', which are debited from your credits (you start with a number of credits). Once spawned, your pastas start flying around. Their instinct is to fly to other plates, in order to conquer them (the objective of the game is"
143152
except_outputs = [1026, 318, 1760, 11, 290, 8948, 13, 921, 460, 711, 705, 34652, 2473, 286, 262, 309, 459, 6386, 6, 319, 5565, 11, 290, 319, 262, 3992, 13, 23911, 319, 262, 3992, 2499, 11, 475, 345, 423, 284, 29308, 3294, 3638, 329, 3084, 3867, 290, 326, 460, 307, 257, 1643, 15337, 13, 1318, 318, 257, 1256, 314, 1549, 588, 284, 1561, 546, 13, 314, 481, 467, 832, 790, 7243, 11, 916, 276, 286, 1642, 262, 7226, 644, 1816, 826, 14, 36460, 1351, 13, 26097, 14594, 625, 262, 7505, 373, 2192, 530, 286, 262, 17612, 8861, 543, 314, 550, 284, 1986, 13, 19486, 11, 314, 550, 281, 2126, 286, 644, 1611, 286, 983, 314, 2227, 284, 1205, 11, 11327, 10787, 532, 1223, 351, 257, 1256, 286, 5775, 14, 529, 669, 11, 2829, 9382, 11, 3863, 900, 287, 2272, 11, 6856, 422, 257, 1353, 12, 2902, 1570, 13, 314, 373, 6563, 326, 314, 714, 4197, 597, 7505, 1088, 340, 13, 554, 262, 886, 11, 262, 1917, 351, 257, 7505, 588, 705, 15200, 2122, 6, 287, 257, 983, 318, 326, 6954, 318, 555, 42191, 13, 632, 4325, 832, 1811, 9775, 4738, 23005, 625, 640, 11, 351, 262, 749, 15409, 9943, 7094, 16997, 13, 770, 8513, 1097, 35375, 318, 11, 287, 616, 4459, 11, 257, 1049, 1672, 286, 4036, 6954, 286, 257, 4693, 6476, 257, 4427, 13, 887, 318, 340, 257, 983, 30, 554, 257, 983, 11, 345, 761, 284, 1630, 1223, 284, 3151, 281, 9432, 13, 1320, 1630, 2925, 1028, 644, 6954, 318, 4385, 284, 307, 588, 13, 1002, 345, 1249, 262, 2836, 284, 2298, 703, 284, 18101, 1223, 11, 340, 338, 407, 6954, 7471, 532, 340, 338, 262, 7548, 286, 12661, 1486, 11, 262, 277, 540, 15646, 416, 6282, 1023, 284, 5249, 262, 2126, 286, 6954, 13, 11204, 556, 43758, 290, 257, 11303, 1878, 3699, 11, 326, 338, 407, 1223, 326, 31862, 502, 262, 826, 835, 13, 16227, 11, 616, 4094, 288, 8270, 2611, 618, 14615, 644, 284, 2251, 373, 407, 351, 644, 314, 2227, 284, 2251, 11, 475, 351, 644, 314, 750, 407, 13, 314, 1422, 470, 765, 284, 2251, 281, 705, 600, 32940, 1486, 6, 35375, 290, 31238, 869, 340, 6954, 13, 770, 318, 257, 1917, 11, 286, 1781, 11, 790, 584, 44047, 635, 550, 284, 1986, 13, 843, 22989, 416, 262, 12784, 8948, 11, 407, 867, 5257, 284, 670, 1088, 340, 13, 314, 1549, 910, 262, 691, 1103, 4610, 373, 832, 262, 779, 286, 11666, 6356, 11, 7599, 13, 1406, 1290, 11, 314, 4398, 470, 1775, 597, 5726, 1262, 428, 379, 663, 4755, 11327, 13, 45315, 11, 428, 318, 655, 257, 1257, 5449, 290, 706, 257, 981, 314, 3066, 407, 284, 307, 355, 7646, 351, 262, 983, 2126, 11, 290, 3142, 3589, 284, 2298, 4232, 314, 1807, 561, 670, 503, 13, 2011, 4238, 2126, 373, 284, 2251, 1223, 810, 9265, 3088, 284, 18101, 284, 257, 1306, 1241, 11, 475, 550, 617, 1611, 286, 22156, 2111, 284, 2245, 606, 422, 1804, 523, 13, 314, 1611, 286, 550, 428, 2939, 286, 1692, 15625, 7348, 287, 2272, 3371, 257, 937, 21446, 393, 257, 2272, 5156, 357, 439, 1912, 287, 5878, 25, 317, 4687, 28032, 286, 1781, 8, 475, 314, 3521, 470, 892, 286, 13206, 357, 961, 25, 2726, 8, 12933, 329, 326, 13, 29004, 82, 547, 616, 1306, 12141, 11, 355, 511, 2187, 14078, 4197, 2495, 880, 656, 262, 6954, 7505, 13, 887, 703, 284, 787, 340, 670, 30, 4231, 345, 262, 275, 2398, 11, 393, 4330, 262, 29004, 30, 383, 2368, 290, 2457, 2126, 1625, 284, 502, 832, 616, 11077, 11, 508, 7599, 2921, 502, 262, 2126, 286, 1642, 1223, 546, 262, 6954, 286, 11303, 64, 13, 383, 517, 314, 1807, 546, 340, 262, 517, 340, 14846, 588, 340, 561, 670, 11, 523, 314, 3066, 284, 467, 351, 340, 13, 32200, 602, 351, 616, 20886, 763, 12, 28816, 371, 516, 20342, 357, 8727, 635, 2727, 262, 705, 28452, 36684, 4698, 22242, 6, 9877, 11112, 329, 616, 493, 4951, 8, 2252, 48491, 262, 3721, 11, 355, 340, 2950, 656, 262, 2126, 286, 1719, 1981, 5207, 286, 26296, 7348, 1088, 290, 2111, 284, 18101, 1566, 484, 2627, 477, 12, 44548, 13, 317, 9233, 2126, 994, 373, 326, 262, 983, 561, 670, 284, 4727, 703, 262, 19903, 1338, 35812, 12635, 1625, 284, 2152, 532, 416, 21568, 422, 257, 3487, 8073, 3084, 13, 1406, 262, 2126, 12572, 517, 393, 1342, 656, 428, 25, 345, 389, 5586, 257, 3084, 13, 921, 423, 534, 898, 7480, 11, 351, 318, 534, 705, 8692, 4458, 1318, 389, 642, 584, 10650, 379, 262, 3084, 11, 1123, 351, 511, 898, 7480, 13, 3406, 7480, 460, 10922, 1310, 5207, 286, 26296, 13, 921, 466, 523, 416, 705, 34555, 6, 606, 832, 257, 6859, 13, 2773, 1613, 292, 389, 1365, 621, 1854, 26, 617, 389, 5443, 11, 617, 389, 7387, 13, 1119, 423, 15874, 705, 15805, 82, 3256, 543, 389, 1915, 863, 422, 534, 10824, 357, 5832, 923, 351, 257, 1271, 286, 10824, 737, 4874, 29013, 11, 534, 1613, 292, 923, 7348, 1088, 13, 5334, 13311, 318, 284, 6129, 284, 584, 13375, 11, 287, 1502, 284, 23875, 606, 357, 1169, 9432, 286, 262, 983, 318, 1719, 534, 26296, 23875, 477, 262, 13375, 319, 262, 3084, 737, 887, 484, 389, 1107, 18284, 11, 523, 706, 852, 29013, 11, 345, 423, 645, 1630, 625, 534, 26296, 357, 14925, 22875, 32, 393, 6706, 43, 49100, 737, 3406, 26296, 1595, 470, 588, 584, 661, 338, 26296, 11, 523, 611, 484, 1826, 11, 484, 2686, 10746, 379, 1123, 584, 1566, 530, 10564, 13, 921, 651, 10824, 329, 584, 1613, 292, 534, 898, 26296, 1494, 13, 4874, 257, 26296, 318, 287, 262, 25980, 286, 257, 7480, 11, 340, 4940, 49977, 340, 329, 663, 1074, 13, 632, 2753, 1088, 838, 4201, 329, 257, 7480, 284, 307, 29346, 26, 1342, 611, 517, 26296, 422, 262, 976, 1074, 389, 1088, 13, 1002, 26296, 422, 584, 1074, 389, 1088, 11, 996, 11, 484, 651, 8970, 866, 287, 511, 2230, 11, 5906, 284, 23875, 262, 7480, 11, 1566, 530, 286, 606, 4656, 357, 14925, 30193, 338, 3210, 705, 3103, 6138, 6, 4235, 737, 921, 651, 2173, 790, 1218, 329, 790, 7480, 345, 898, 13, 3827, 640, 11, 262, 3721, 12572, 656, 428, 25, 345, 389, 5586, 257, 3084, 13, 921, 423, 534, 898, 7480, 11, 351, 318, 534, 705, 8692, 4458, 1318, 389, 642, 584, 10650, 379, 262, 3084, 11, 1123, 351, 511, 898, 7480, 13, 3406, 7480, 460, 10922, 1310, 5207, 286, 26296, 13, 921, 466, 523, 416, 705, 34555, 6, 606, 832, 257, 6859, 13, 2773, 1613, 292, 389, 1365, 621, 1854, 26, 617, 389, 5443, 11, 617, 389, 7387, 13, 1119, 423, 15874, 705, 15805, 82, 3256, 543, 389, 1915, 863, 422, 534, 10824, 357, 5832, 923, 351, 257, 1271, 286, 10824, 737, 4874, 29013, 11, 534, 1613, 292, 923, 7348, 1088, 13, 5334, 13311, 318, 284, 6129, 284, 584, 13375, 11, 287, 1502, 284, 23875, 606, 357, 1169, 9432, 286, 262, 983, 318]
144-
acc = (outputs == except_outputs).sum().item()/except_outputs.size
145-
print("Accuracy = {}".format(acc))
153+
154+
if max_new_tokens == 1:
155+
index = int(input_tokens)
156+
acc = (outputs[index] == except_outputs[index])
157+
else:
158+
acc = (outputs == except_outputs).sum().item()/len(except_outputs)
159+
print("Accuracy = {:.2f}".format(acc))

0 commit comments

Comments
 (0)