47
47
parser .add_argument ("--num-iter" , default = 10 , type = int , help = "num iter" )
48
48
parser .add_argument ("--num-warmup" , default = 3 , type = int , help = "num warmup" )
49
49
parser .add_argument ("--batch-size" , default = 1 , type = int , help = "batch size" )
50
- parser .add_argument ("--accuracy-only " , action = "store_true" )
50
+ parser .add_argument ("--accuracy" , action = "store_true" )
51
51
args = parser .parse_args ()
52
52
print (args )
53
53
64
64
print ("---- Model loading done" , flush = True )
65
65
66
66
# input_ids
67
- if not args .accuracy_only and args .prompt is not None :
67
+ if not args .accuracy and args .prompt is not None :
68
68
prompt = args .prompt
69
69
else :
70
70
try :
@@ -91,42 +91,51 @@ def run_model(input_ids):
91
91
gen_tokens = model .generate (input_ids , max_new_tokens = max_new_tokens , ** generate_kwargs )
92
92
return gen_tokens
93
93
94
- if not args .accuracy_only :
95
- total_time = 0.0
96
- num_iter = args .num_iter
97
- num_warmup = args .num_warmup
98
- for i in range (num_iter ):
99
- tic = time .time ()
100
- gen_tokens = run_model (input_ids )
101
- gen_text = tokenizer .batch_decode (gen_tokens [0 ], skip_special_tokens = False )
102
- toc = time .time ()
103
- print (gen_text , flush = True )
104
- print ("Inference latency: %.3f sec." % (toc - tic ), flush = True )
105
- dur = toc - tic
106
- if i >= num_warmup :
107
- total_time += dur
108
- print ("\n " , "-" * 10 , "Summary:" , "-" * 10 , flush = True )
109
- latency = total_time / (num_iter - num_warmup )
110
- print ("Inference latency: %.3f sec." % (latency ), flush = True )
111
- print ("Inference throughput: %.3f samples/sec.\n " % (args .batch_size / latency ), flush = True )
112
- else :
113
- accuracy_check_combination = [[1 , '32' , 32 ], [4 , '32' , 32 ], [1 , '1024' , 128 ], [4 , '1024' , 128 ]]
114
- if [num_beams , input_tokens , max_new_tokens ] not in accuracy_check_combination :
94
+ total_time = 0.0
95
+ num_iter = args .num_iter
96
+ num_warmup = args .num_warmup
97
+ for i in range (num_iter ):
98
+ tic = time .time ()
99
+ gen_tokens = run_model (input_ids )
100
+ gen_text = tokenizer .batch_decode (gen_tokens [0 ], skip_special_tokens = False )
101
+ toc = time .time ()
102
+ print (gen_text , flush = True )
103
+ print ("Inference latency: %.3f sec." % (toc - tic ), flush = True )
104
+ dur = toc - tic
105
+ if i >= num_warmup :
106
+ total_time += dur
107
+ print ("\n " , "-" * 10 , "Summary:" , "-" * 10 , flush = True )
108
+ latency = total_time / (num_iter - num_warmup )
109
+ print ("Inference latency: %.3f sec." % (latency ), flush = True )
110
+ print ("Inference throughput: %.3f samples/sec.\n " % (args .batch_size / latency ), flush = True )
111
+
112
+ if args .accuracy :
113
+ invalid_flag = False
114
+ accuracy_check_combination = [['32' , 32 ], ['32' , 1 ], ['1024' , 128 ], ['1024' , 1 ]]
115
+
116
+ if num_beams != 1 and num_beams != 4 :
117
+ invalid_flag = True
118
+
119
+ if [input_tokens , max_new_tokens ] not in accuracy_check_combination :
120
+ invalid_flag = True
121
+
122
+ if invalid_flag :
115
123
print ((num_beams , input_tokens , max_new_tokens ), "is not a supported combination " \
116
124
"[num_beams, input_tokens, max_new_tokens] for checking accuracy" )
117
125
exit (- 1 )
126
+
118
127
gen_tokens = run_model (input_ids )
119
128
outputs = np .array (gen_tokens [0 ])[0 ]
120
129
except_outputs = []
121
- if [num_beams , input_tokens , max_new_tokens ] == [1 , '32' , 32 ]:
130
+ if [num_beams , input_tokens ] == [1 , '32' ]:
122
131
# "She wanted to be a princess, and live in a castle. She wanted to be a mermaid, and live in the sea. She wanted to be a"
123
132
except_outputs = [ 7454 , 2402 , 257 , 640 , 11 , 612 , 11196 , 257 , 2576 , 11 , 508 ,
124
133
8288 , 284 , 423 , 17545 , 13 , 1375 , 2227 , 284 , 284 , 4113 , 290 ,
125
134
1826 , 649 , 661 , 11 , 290 , 423 , 1257 , 13 , 2227 , 284 , 307 ,
126
135
257 , 21752 , 11 , 290 , 2107 , 287 , 257 , 16669 , 1375 , 2227 , 284 ,
127
136
307 , 257 , 4017 , 23151 , 11 , 290 , 2107 , 287 , 5417 , 13 , 1375 ,
128
137
2227 , 284 , 307 , 257 ,]
129
- elif [num_beams , input_tokens , max_new_tokens ] == [4 , '32' , 32 ]:
138
+ elif [num_beams , input_tokens ] == [4 , '32' ]:
130
139
# "One day, she decided to go on an adventure. She packed her bags, and set off on her journey.\n\nThe little girl walked and walked,"
131
140
except_outputs = [ 7454 , 2402 , 257 , 640 , 11 , 612 , 11196 , 257 , 1310 ,
132
141
2576 , 11 , 508 , 8288 , 284 , 423 , 17545 , 13 , 1375 ,
@@ -135,11 +144,16 @@ def run_model(input_ids):
135
144
3066 , 284 , 467 , 319 , 281 , 8855 , 13 , 1375 , 11856 ,
136
145
607 , 11668 , 11 , 290 , 900 , 572 , 319 , 607 , 7002 ,
137
146
13 , 198 , 198 , 464 , 1310 , 2576 , 6807 , 290 , 6807 , 11 ]
138
- elif [num_beams , input_tokens , max_new_tokens ] == [1 , '1024' , 128 ]:
147
+ elif [num_beams , input_tokens ] == [1 , '1024' ]:
139
148
# "Over time, the concept evolved into a game where you are the pasta, and you are trying to evolve to become the most powerful pasta. You can evolve by eating other pastas, or by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas. You can also evolve by killing other pastas"
140
149
except_outputs = [1026 , 318 , 1760 , 11 , 290 , 8948 , 13 , 921 , 460 , 711 , 705 , 34652 , 2473 , 286 , 262 , 309 , 459 , 6386 , 6 , 319 , 5565 , 11 , 290 , 319 , 262 , 3992 , 13 , 23911 , 319 , 262 , 3992 , 2499 , 11 , 475 , 345 , 423 , 284 , 29308 , 3294 , 3638 , 329 , 3084 , 3867 , 290 , 326 , 460 , 307 , 257 , 1643 , 15337 , 13 , 1318 , 318 , 257 , 1256 , 314 , 1549 , 588 , 284 , 1561 , 546 , 13 , 314 , 481 , 467 , 832 , 790 , 7243 , 11 , 916 , 276 , 286 , 1642 , 262 , 7226 , 644 , 1816 , 826 , 14 , 36460 , 1351 , 13 , 26097 , 14594 , 625 , 262 , 7505 , 373 , 2192 , 530 , 286 , 262 , 17612 , 8861 , 543 , 314 , 550 , 284 , 1986 , 13 , 19486 , 11 , 314 , 550 , 281 , 2126 , 286 , 644 , 1611 , 286 , 983 , 314 , 2227 , 284 , 1205 , 11 , 11327 , 10787 , 532 , 1223 , 351 , 257 , 1256 , 286 , 5775 , 14 , 529 , 669 , 11 , 2829 , 9382 , 11 , 3863 , 900 , 287 , 2272 , 11 , 6856 , 422 , 257 , 1353 , 12 , 2902 , 1570 , 13 , 314 , 373 , 6563 , 326 , 314 , 714 , 4197 , 597 , 7505 , 1088 , 340 , 13 , 554 , 262 , 886 , 11 , 262 , 1917 , 351 , 257 , 7505 , 588 , 705 , 15200 , 2122 , 6 , 287 , 257 , 983 , 318 , 326 , 6954 , 318 , 555 , 42191 , 13 , 632 , 4325 , 832 , 1811 , 9775 , 4738 , 23005 , 625 , 640 , 11 , 351 , 262 , 749 , 15409 , 9943 , 7094 , 16997 , 13 , 770 , 8513 , 1097 , 35375 , 318 , 11 , 287 , 616 , 4459 , 11 , 257 , 1049 , 1672 , 286 , 4036 , 6954 , 286 , 257 , 4693 , 6476 , 257 , 4427 , 13 , 887 , 318 , 340 , 257 , 983 , 30 , 554 , 257 , 983 , 11 , 345 , 761 , 284 , 1630 , 1223 , 284 , 3151 , 281 , 9432 , 13 , 1320 , 1630 , 2925 , 1028 , 644 , 6954 , 318 , 4385 , 284 , 307 , 588 , 13 , 1002 , 345 , 1249 , 262 , 2836 , 284 , 2298 , 703 , 284 , 18101 , 1223 , 11 , 340 , 338 , 407 , 6954 , 7471 , 532 , 340 , 338 , 262 , 7548 , 286 , 12661 , 1486 , 11 , 262 , 277 , 540 , 15646 , 416 , 6282 , 1023 , 284 , 5249 , 262 , 2126 , 286 , 6954 , 13 , 11204 , 556 , 43758 , 290 , 257 , 11303 , 1878 , 3699 , 11 , 326 , 338 , 407 , 1223 , 326 , 31862 , 502 , 262 , 826 , 835 , 13 , 16227 , 11 , 616 , 4094 , 288 , 8270 , 2611 , 618 , 14615 , 644 , 284 , 2251 , 373 , 407 , 351 , 644 , 314 , 2227 , 284 , 2251 , 11 , 475 , 351 , 644 , 314 , 750 , 407 , 13 , 314 , 1422 , 470 , 765 , 284 , 2251 , 281 , 705 , 600 , 32940 , 1486 , 6 , 35375 , 290 , 31238 , 869 , 340 , 6954 , 13 , 770 , 318 , 257 , 1917 , 11 , 286 , 1781 , 11 , 790 , 584 , 44047 , 635 , 550 , 284 , 1986 , 13 , 843 , 22989 , 416 , 262 , 12784 , 8948 , 11 , 407 , 867 , 5257 , 284 , 670 , 1088 , 340 , 13 , 314 , 1549 , 910 , 262 , 691 , 1103 , 4610 , 373 , 832 , 262 , 779 , 286 , 11666 , 6356 , 11 , 7599 , 13 , 1406 , 1290 , 11 , 314 , 4398 , 470 , 1775 , 597 , 5726 , 1262 , 428 , 379 , 663 , 4755 , 11327 , 13 , 45315 , 11 , 428 , 318 , 655 , 257 , 1257 , 5449 , 290 , 706 , 257 , 981 , 314 , 3066 , 407 , 284 , 307 , 355 , 7646 , 351 , 262 , 983 , 2126 , 11 , 290 , 3142 , 3589 , 284 , 2298 , 4232 , 314 , 1807 , 561 , 670 , 503 , 13 , 2011 , 4238 , 2126 , 373 , 284 , 2251 , 1223 , 810 , 9265 , 3088 , 284 , 18101 , 284 , 257 , 1306 , 1241 , 11 , 475 , 550 , 617 , 1611 , 286 , 22156 , 2111 , 284 , 2245 , 606 , 422 , 1804 , 523 , 13 , 314 , 1611 , 286 , 550 , 428 , 2939 , 286 , 1692 , 15625 , 7348 , 287 , 2272 , 3371 , 257 , 937 , 21446 , 393 , 257 , 2272 , 5156 , 357 , 439 , 1912 , 287 , 5878 , 25 , 317 , 4687 , 28032 , 286 , 1781 , 8 , 475 , 314 , 3521 , 470 , 892 , 286 , 13206 , 357 , 961 , 25 , 2726 , 8 , 12933 , 329 , 326 , 13 , 29004 , 82 , 547 , 616 , 1306 , 12141 , 11 , 355 , 511 , 2187 , 14078 , 4197 , 2495 , 880 , 656 , 262 , 6954 , 7505 , 13 , 887 , 703 , 284 , 787 , 340 , 670 , 30 , 4231 , 345 , 262 , 275 , 2398 , 11 , 393 , 4330 , 262 , 29004 , 30 , 383 , 2368 , 290 , 2457 , 2126 , 1625 , 284 , 502 , 832 , 616 , 11077 , 11 , 508 , 7599 , 2921 , 502 , 262 , 2126 , 286 , 1642 , 1223 , 546 , 262 , 6954 , 286 , 11303 , 64 , 13 , 383 , 517 , 314 , 1807 , 546 , 340 , 262 , 517 , 340 , 14846 , 588 , 340 , 561 , 670 , 11 , 523 , 314 , 3066 , 284 , 467 , 351 , 340 , 13 , 32200 , 602 , 351 , 616 , 20886 , 763 , 12 , 28816 , 371 , 516 , 20342 , 357 , 8727 , 635 , 2727 , 262 , 705 , 28452 , 36684 , 4698 , 22242 , 6 , 9877 , 11112 , 329 , 616 , 493 , 4951 , 8 , 2252 , 48491 , 262 , 3721 , 11 , 355 , 340 , 2950 , 656 , 262 , 2126 , 286 , 1719 , 1981 , 5207 , 286 , 26296 , 7348 , 1088 , 290 , 2111 , 284 , 18101 , 1566 , 484 , 2627 , 477 , 12 , 44548 , 13 , 317 , 9233 , 2126 , 994 , 373 , 326 , 262 , 983 , 561 , 670 , 284 , 4727 , 703 , 262 , 19903 , 1338 , 35812 , 12635 , 1625 , 284 , 2152 , 532 , 416 , 21568 , 422 , 257 , 3487 , 8073 , 3084 , 13 , 1406 , 262 , 2126 , 12572 , 517 , 393 , 1342 , 656 , 428 , 25 , 345 , 389 , 5586 , 257 , 3084 , 13 , 921 , 423 , 534 , 898 , 7480 , 11 , 351 , 318 , 534 , 705 , 8692 , 4458 , 1318 , 389 , 642 , 584 , 10650 , 379 , 262 , 3084 , 11 , 1123 , 351 , 511 , 898 , 7480 , 13 , 3406 , 7480 , 460 , 10922 , 1310 , 5207 , 286 , 26296 , 13 , 921 , 466 , 523 , 416 , 705 , 34555 , 6 , 606 , 832 , 257 , 6859 , 13 , 2773 , 1613 , 292 , 389 , 1365 , 621 , 1854 , 26 , 617 , 389 , 5443 , 11 , 617 , 389 , 7387 , 13 , 1119 , 423 , 15874 , 705 , 15805 , 82 , 3256 , 543 , 389 , 1915 , 863 , 422 , 534 , 10824 , 357 , 5832 , 923 , 351 , 257 , 1271 , 286 , 10824 , 737 , 4874 , 29013 , 11 , 534 , 1613 , 292 , 923 , 7348 , 1088 , 13 , 5334 , 13311 , 318 , 284 , 6129 , 284 , 584 , 13375 , 11 , 287 , 1502 , 284 , 23875 , 606 , 357 , 1169 , 9432 , 286 , 262 , 983 , 318 , 1719 , 534 , 26296 , 23875 , 477 , 262 , 13375 , 319 , 262 , 3084 , 737 , 887 , 484 , 389 , 1107 , 18284 , 11 , 523 , 706 , 852 , 29013 , 11 , 345 , 423 , 645 , 1630 , 625 , 534 , 26296 , 357 , 14925 , 22875 , 32 , 393 , 6706 , 43 , 49100 , 737 , 3406 , 26296 , 1595 , 470 , 588 , 584 , 661 , 338 , 26296 , 11 , 523 , 611 , 484 , 1826 , 11 , 484 , 2686 , 10746 , 379 , 1123 , 584 , 1566 , 530 , 10564 , 13 , 921 , 651 , 10824 , 329 , 584 , 1613 , 292 , 534 , 898 , 26296 , 1494 , 13 , 4874 , 257 , 26296 , 318 , 287 , 262 , 25980 , 286 , 257 , 7480 , 11 , 340 , 4940 , 49977 , 340 , 329 , 663 , 1074 , 13 , 632 , 2753 , 1088 , 838 , 4201 , 329 , 257 , 7480 , 284 , 307 , 29346 , 26 , 1342 , 611 , 517 , 26296 , 422 , 262 , 976 , 1074 , 389 , 1088 , 13 , 1002 , 26296 , 422 , 584 , 1074 , 389 , 1088 , 11 , 996 , 11 , 484 , 651 , 8970 , 866 , 287 , 511 , 2230 , 11 , 5906 , 284 , 23875 , 262 , 7480 , 11 , 1566 , 530 , 286 , 606 , 4656 , 357 , 14925 , 30193 , 338 , 3210 , 705 , 3103 , 6138 , 6 , 4235 , 737 , 921 , 651 , 2173 , 790 , 1218 , 329 , 790 , 7480 , 345 , 898 , 13 , 3827 , 640 , 11 , 262 , 3721 , 12572 , 656 , 257 , 983 , 810 , 345 , 389 , 262 , 26296 , 11 , 290 , 345 , 389 , 2111 , 284 , 18101 , 284 , 1716 , 262 , 749 , 3665 , 26296 , 13 , 921 , 460 , 18101 , 416 , 6600 , 584 , 1613 , 292 , 11 , 393 , 416 , 5170 , 584 , 1613 , 292 , 13 , 921 , 460 , 635 , 18101 , 416 , 5170 , 584 , 1613 , 292 , 13 , 921 , 460 , 635 , 18101 , 416 , 5170 , 584 , 1613 , 292 , 13 , 921 , 460 , 635 , 18101 , 416 , 5170 , 584 , 1613 , 292 , 13 , 921 , 460 , 635 , 18101 , 416 , 5170 , 584 , 1613 , 292 , 13 , 921 , 460 , 635 , 18101 , 416 , 5170 , 584 , 1613 , 292 , 13 , 921 , 460 , 635 , 18101 , 416 , 5170 , 584 , 1613 , 292 , 13 , 921 , 460 , 635 , 18101 , 416 , 5170 , 584 , 1613 , 292 , 13 , 921 , 460 , 635 , 18101 , 416 , 5170 , 584 , 1613 , 292 , 13 , 921 , 460 , 635 , 18101 , 416 , 5170 , 584 , 1613 , 292 ]
141
150
else :
142
151
# "Over time, the concept evolved into this: you are sitting a table. You have your own plate, with is your 'base'. There are 5 other guests at the table, each with their own plate. Your plate can spawn little pieces of pasta. You do so by 'ordering' them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying 'costs', which are debited from your credits (you start with a number of credits). Once spawned, your pastas start flying around. Their instinct is to fly to other plates, in order to conquer them (the objective of the game is"
143
152
except_outputs = [1026 , 318 , 1760 , 11 , 290 , 8948 , 13 , 921 , 460 , 711 , 705 , 34652 , 2473 , 286 , 262 , 309 , 459 , 6386 , 6 , 319 , 5565 , 11 , 290 , 319 , 262 , 3992 , 13 , 23911 , 319 , 262 , 3992 , 2499 , 11 , 475 , 345 , 423 , 284 , 29308 , 3294 , 3638 , 329 , 3084 , 3867 , 290 , 326 , 460 , 307 , 257 , 1643 , 15337 , 13 , 1318 , 318 , 257 , 1256 , 314 , 1549 , 588 , 284 , 1561 , 546 , 13 , 314 , 481 , 467 , 832 , 790 , 7243 , 11 , 916 , 276 , 286 , 1642 , 262 , 7226 , 644 , 1816 , 826 , 14 , 36460 , 1351 , 13 , 26097 , 14594 , 625 , 262 , 7505 , 373 , 2192 , 530 , 286 , 262 , 17612 , 8861 , 543 , 314 , 550 , 284 , 1986 , 13 , 19486 , 11 , 314 , 550 , 281 , 2126 , 286 , 644 , 1611 , 286 , 983 , 314 , 2227 , 284 , 1205 , 11 , 11327 , 10787 , 532 , 1223 , 351 , 257 , 1256 , 286 , 5775 , 14 , 529 , 669 , 11 , 2829 , 9382 , 11 , 3863 , 900 , 287 , 2272 , 11 , 6856 , 422 , 257 , 1353 , 12 , 2902 , 1570 , 13 , 314 , 373 , 6563 , 326 , 314 , 714 , 4197 , 597 , 7505 , 1088 , 340 , 13 , 554 , 262 , 886 , 11 , 262 , 1917 , 351 , 257 , 7505 , 588 , 705 , 15200 , 2122 , 6 , 287 , 257 , 983 , 318 , 326 , 6954 , 318 , 555 , 42191 , 13 , 632 , 4325 , 832 , 1811 , 9775 , 4738 , 23005 , 625 , 640 , 11 , 351 , 262 , 749 , 15409 , 9943 , 7094 , 16997 , 13 , 770 , 8513 , 1097 , 35375 , 318 , 11 , 287 , 616 , 4459 , 11 , 257 , 1049 , 1672 , 286 , 4036 , 6954 , 286 , 257 , 4693 , 6476 , 257 , 4427 , 13 , 887 , 318 , 340 , 257 , 983 , 30 , 554 , 257 , 983 , 11 , 345 , 761 , 284 , 1630 , 1223 , 284 , 3151 , 281 , 9432 , 13 , 1320 , 1630 , 2925 , 1028 , 644 , 6954 , 318 , 4385 , 284 , 307 , 588 , 13 , 1002 , 345 , 1249 , 262 , 2836 , 284 , 2298 , 703 , 284 , 18101 , 1223 , 11 , 340 , 338 , 407 , 6954 , 7471 , 532 , 340 , 338 , 262 , 7548 , 286 , 12661 , 1486 , 11 , 262 , 277 , 540 , 15646 , 416 , 6282 , 1023 , 284 , 5249 , 262 , 2126 , 286 , 6954 , 13 , 11204 , 556 , 43758 , 290 , 257 , 11303 , 1878 , 3699 , 11 , 326 , 338 , 407 , 1223 , 326 , 31862 , 502 , 262 , 826 , 835 , 13 , 16227 , 11 , 616 , 4094 , 288 , 8270 , 2611 , 618 , 14615 , 644 , 284 , 2251 , 373 , 407 , 351 , 644 , 314 , 2227 , 284 , 2251 , 11 , 475 , 351 , 644 , 314 , 750 , 407 , 13 , 314 , 1422 , 470 , 765 , 284 , 2251 , 281 , 705 , 600 , 32940 , 1486 , 6 , 35375 , 290 , 31238 , 869 , 340 , 6954 , 13 , 770 , 318 , 257 , 1917 , 11 , 286 , 1781 , 11 , 790 , 584 , 44047 , 635 , 550 , 284 , 1986 , 13 , 843 , 22989 , 416 , 262 , 12784 , 8948 , 11 , 407 , 867 , 5257 , 284 , 670 , 1088 , 340 , 13 , 314 , 1549 , 910 , 262 , 691 , 1103 , 4610 , 373 , 832 , 262 , 779 , 286 , 11666 , 6356 , 11 , 7599 , 13 , 1406 , 1290 , 11 , 314 , 4398 , 470 , 1775 , 597 , 5726 , 1262 , 428 , 379 , 663 , 4755 , 11327 , 13 , 45315 , 11 , 428 , 318 , 655 , 257 , 1257 , 5449 , 290 , 706 , 257 , 981 , 314 , 3066 , 407 , 284 , 307 , 355 , 7646 , 351 , 262 , 983 , 2126 , 11 , 290 , 3142 , 3589 , 284 , 2298 , 4232 , 314 , 1807 , 561 , 670 , 503 , 13 , 2011 , 4238 , 2126 , 373 , 284 , 2251 , 1223 , 810 , 9265 , 3088 , 284 , 18101 , 284 , 257 , 1306 , 1241 , 11 , 475 , 550 , 617 , 1611 , 286 , 22156 , 2111 , 284 , 2245 , 606 , 422 , 1804 , 523 , 13 , 314 , 1611 , 286 , 550 , 428 , 2939 , 286 , 1692 , 15625 , 7348 , 287 , 2272 , 3371 , 257 , 937 , 21446 , 393 , 257 , 2272 , 5156 , 357 , 439 , 1912 , 287 , 5878 , 25 , 317 , 4687 , 28032 , 286 , 1781 , 8 , 475 , 314 , 3521 , 470 , 892 , 286 , 13206 , 357 , 961 , 25 , 2726 , 8 , 12933 , 329 , 326 , 13 , 29004 , 82 , 547 , 616 , 1306 , 12141 , 11 , 355 , 511 , 2187 , 14078 , 4197 , 2495 , 880 , 656 , 262 , 6954 , 7505 , 13 , 887 , 703 , 284 , 787 , 340 , 670 , 30 , 4231 , 345 , 262 , 275 , 2398 , 11 , 393 , 4330 , 262 , 29004 , 30 , 383 , 2368 , 290 , 2457 , 2126 , 1625 , 284 , 502 , 832 , 616 , 11077 , 11 , 508 , 7599 , 2921 , 502 , 262 , 2126 , 286 , 1642 , 1223 , 546 , 262 , 6954 , 286 , 11303 , 64 , 13 , 383 , 517 , 314 , 1807 , 546 , 340 , 262 , 517 , 340 , 14846 , 588 , 340 , 561 , 670 , 11 , 523 , 314 , 3066 , 284 , 467 , 351 , 340 , 13 , 32200 , 602 , 351 , 616 , 20886 , 763 , 12 , 28816 , 371 , 516 , 20342 , 357 , 8727 , 635 , 2727 , 262 , 705 , 28452 , 36684 , 4698 , 22242 , 6 , 9877 , 11112 , 329 , 616 , 493 , 4951 , 8 , 2252 , 48491 , 262 , 3721 , 11 , 355 , 340 , 2950 , 656 , 262 , 2126 , 286 , 1719 , 1981 , 5207 , 286 , 26296 , 7348 , 1088 , 290 , 2111 , 284 , 18101 , 1566 , 484 , 2627 , 477 , 12 , 44548 , 13 , 317 , 9233 , 2126 , 994 , 373 , 326 , 262 , 983 , 561 , 670 , 284 , 4727 , 703 , 262 , 19903 , 1338 , 35812 , 12635 , 1625 , 284 , 2152 , 532 , 416 , 21568 , 422 , 257 , 3487 , 8073 , 3084 , 13 , 1406 , 262 , 2126 , 12572 , 517 , 393 , 1342 , 656 , 428 , 25 , 345 , 389 , 5586 , 257 , 3084 , 13 , 921 , 423 , 534 , 898 , 7480 , 11 , 351 , 318 , 534 , 705 , 8692 , 4458 , 1318 , 389 , 642 , 584 , 10650 , 379 , 262 , 3084 , 11 , 1123 , 351 , 511 , 898 , 7480 , 13 , 3406 , 7480 , 460 , 10922 , 1310 , 5207 , 286 , 26296 , 13 , 921 , 466 , 523 , 416 , 705 , 34555 , 6 , 606 , 832 , 257 , 6859 , 13 , 2773 , 1613 , 292 , 389 , 1365 , 621 , 1854 , 26 , 617 , 389 , 5443 , 11 , 617 , 389 , 7387 , 13 , 1119 , 423 , 15874 , 705 , 15805 , 82 , 3256 , 543 , 389 , 1915 , 863 , 422 , 534 , 10824 , 357 , 5832 , 923 , 351 , 257 , 1271 , 286 , 10824 , 737 , 4874 , 29013 , 11 , 534 , 1613 , 292 , 923 , 7348 , 1088 , 13 , 5334 , 13311 , 318 , 284 , 6129 , 284 , 584 , 13375 , 11 , 287 , 1502 , 284 , 23875 , 606 , 357 , 1169 , 9432 , 286 , 262 , 983 , 318 , 1719 , 534 , 26296 , 23875 , 477 , 262 , 13375 , 319 , 262 , 3084 , 737 , 887 , 484 , 389 , 1107 , 18284 , 11 , 523 , 706 , 852 , 29013 , 11 , 345 , 423 , 645 , 1630 , 625 , 534 , 26296 , 357 , 14925 , 22875 , 32 , 393 , 6706 , 43 , 49100 , 737 , 3406 , 26296 , 1595 , 470 , 588 , 584 , 661 , 338 , 26296 , 11 , 523 , 611 , 484 , 1826 , 11 , 484 , 2686 , 10746 , 379 , 1123 , 584 , 1566 , 530 , 10564 , 13 , 921 , 651 , 10824 , 329 , 584 , 1613 , 292 , 534 , 898 , 26296 , 1494 , 13 , 4874 , 257 , 26296 , 318 , 287 , 262 , 25980 , 286 , 257 , 7480 , 11 , 340 , 4940 , 49977 , 340 , 329 , 663 , 1074 , 13 , 632 , 2753 , 1088 , 838 , 4201 , 329 , 257 , 7480 , 284 , 307 , 29346 , 26 , 1342 , 611 , 517 , 26296 , 422 , 262 , 976 , 1074 , 389 , 1088 , 13 , 1002 , 26296 , 422 , 584 , 1074 , 389 , 1088 , 11 , 996 , 11 , 484 , 651 , 8970 , 866 , 287 , 511 , 2230 , 11 , 5906 , 284 , 23875 , 262 , 7480 , 11 , 1566 , 530 , 286 , 606 , 4656 , 357 , 14925 , 30193 , 338 , 3210 , 705 , 3103 , 6138 , 6 , 4235 , 737 , 921 , 651 , 2173 , 790 , 1218 , 329 , 790 , 7480 , 345 , 898 , 13 , 3827 , 640 , 11 , 262 , 3721 , 12572 , 656 , 428 , 25 , 345 , 389 , 5586 , 257 , 3084 , 13 , 921 , 423 , 534 , 898 , 7480 , 11 , 351 , 318 , 534 , 705 , 8692 , 4458 , 1318 , 389 , 642 , 584 , 10650 , 379 , 262 , 3084 , 11 , 1123 , 351 , 511 , 898 , 7480 , 13 , 3406 , 7480 , 460 , 10922 , 1310 , 5207 , 286 , 26296 , 13 , 921 , 466 , 523 , 416 , 705 , 34555 , 6 , 606 , 832 , 257 , 6859 , 13 , 2773 , 1613 , 292 , 389 , 1365 , 621 , 1854 , 26 , 617 , 389 , 5443 , 11 , 617 , 389 , 7387 , 13 , 1119 , 423 , 15874 , 705 , 15805 , 82 , 3256 , 543 , 389 , 1915 , 863 , 422 , 534 , 10824 , 357 , 5832 , 923 , 351 , 257 , 1271 , 286 , 10824 , 737 , 4874 , 29013 , 11 , 534 , 1613 , 292 , 923 , 7348 , 1088 , 13 , 5334 , 13311 , 318 , 284 , 6129 , 284 , 584 , 13375 , 11 , 287 , 1502 , 284 , 23875 , 606 , 357 , 1169 , 9432 , 286 , 262 , 983 , 318 ]
144
- acc = (outputs == except_outputs ).sum ().item ()/ except_outputs .size
145
- print ("Accuracy = {}" .format (acc ))
153
+
154
+ if max_new_tokens == 1 :
155
+ index = int (input_tokens )
156
+ acc = (outputs [index ] == except_outputs [index ])
157
+ else :
158
+ acc = (outputs == except_outputs ).sum ().item ()/ len (except_outputs )
159
+ print ("Accuracy = {:.2f}" .format (acc ))
0 commit comments