Commit 181145a

Author: Lu Teng
[Example] Add accuracy check for GPT-J with different layer num (#434)
1 parent 8a3c857 commit 181145a

File tree

2 files changed: +36 −4 lines changed


example/gptj/README.md

Lines changed: 10 additions & 0 deletions
@@ -22,6 +22,7 @@ pip install -r ../../test/requirements.txt
 | *```--max-new-tokens```*| *```32```*| Output max new tokens. |
 | *```--greedy```*| *```False```*| Enable greedy search or beam search. |
 | *```--num-iter```*| *```10```*| Number of iterations. |
+| *```--num-layer```*| *```28```*| Number of hidden layers. |
 | *```--num-warmup```*| *```3```*| Number of warmup iterations. |
 | *```--accuracy```*| *```False```*| Run accuracy check. |

@@ -49,6 +50,7 @@ python jax_gptj.py --input-tokens 1024 --max-new-tokens 128
 ```

 ### Performance Output
+
 ```
 Inference latency: x.xxx sec.
 Inference throughput: x.xxx samples/sec.
@@ -67,3 +69,11 @@ Inference throughput: x.xxx samples/sec.

 Accuracy = 1.00
 ```
+
+### Test with less memory
+
+Set the option `--num-layer` (default value: `28`) to a smaller number to reduce the memory footprint of the test.
+```bash
+export ZE_AFFINITY_MASK=0
+python jax_gptj.py --input-tokens 1024 --max-new-tokens 128 --accuracy --num-layer 14
+```
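The saving from `--num-layer` is roughly linear in the layer count, since each transformer block carries its own attention and MLP weights while the embedding and LM head are paid only once. Below is a back-of-envelope sketch of the weight memory, assuming GPT-J-6B's published shapes (`n_embd = 4096`, vocabulary ≈ 50400) and fp32 weights; biases and layer norms are ignored, so these are estimates, not measurements from this repo:

```python
# Rough GPT-J-6B weight-memory estimate as a function of --num-layer.
# Assumed shapes: n_embd = 4096, vocab = 50400; fp32 = 4 bytes per param.
d = 4096                                   # hidden size (n_embd)
vocab = 50400                              # GPT-J vocabulary size
per_block = 4 * d * d + 2 * (4 * d) * d    # attention q/k/v/out + MLP fc_in/fc_out
embed = 2 * vocab * d                      # token embedding + LM head (not tied)

for n_layer in (28, 14):
    params = n_layer * per_block + embed
    print(f"n_layer={n_layer}: ~{params / 1e9:.2f}B params, "
          f"~{params * 4 / 2**30:.1f} GiB in fp32")
```

Halving the layers cuts the estimate from roughly 6.0B to 3.2B parameters, i.e. from about 22.5 GiB to about 12 GiB of fp32 weights, which matches the intent of the `--num-layer 14` example above.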

example/gptj/jax_gptj.py

Lines changed: 26 additions & 4 deletions
@@ -45,6 +45,7 @@
 parser.add_argument("--input-tokens", default="32", type=str)
 parser.add_argument("--prompt", default=None, type=str)
 parser.add_argument("--num-iter", default=10, type=int, help="num iter")
+parser.add_argument("--num-layer", default=28, type=int, help="num hidden layers")
 parser.add_argument("--num-warmup", default=3, type=int, help="num warmup")
 parser.add_argument("--batch-size", default=1, type=int, help="batch size")
 parser.add_argument("--accuracy", action="store_true")
@@ -61,6 +62,8 @@
     model.params = model.to_bf16(model.params)
 else:
     model = FlaxGPTJForCausalLM.from_pretrained(model_id, dtype=jax.numpy.float32)
+model.config.n_layer = args.num_layer
+print(model.config)
 print("---- Model loading done", flush=True)

 # input_ids
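A note on the override: the commit mutates `model.config.n_layer` after the checkpoint is loaded and prints the config to confirm the change. In `transformers`, extra keyword arguments to `from_pretrained` are typically forwarded as config overrides, so the same truncation could also be requested at load time. A minimal sketch of that variant; the checkpoint id and the kwarg-forwarding behavior are assumptions here, not something this commit relies on:

```python
# Hypothetical alternative: request the truncation at load time rather than
# mutating model.config afterwards. Unknown from_pretrained kwargs are
# forwarded as config overrides, so the Flax module is built with 14 blocks;
# weights for the remaining blocks are reported as unused.
import jax
from transformers import FlaxGPTJForCausalLM

model = FlaxGPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6b",          # assumed checkpoint id
    dtype=jax.numpy.float32,
    n_layer=14,                     # config override, assumed for illustration
)
print(model.config.n_layer)         # -> 14
```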
@@ -110,6 +113,8 @@ def run_model(input_ids):
     print("Inference throughput: %.3f samples/sec.\n" % (args.batch_size / latency), flush=True)

 if args.accuracy:
+    index = int(input_tokens)
+    n_layer = args.num_layer
     invalid_flag = False
     accuracy_check_combination = [['32', 32], ['32', 1], ['1024', 128], ['1024', 1]]

@@ -119,9 +124,13 @@ def run_model(input_ids):
     if [input_tokens, max_new_tokens] not in accuracy_check_combination:
         invalid_flag = True

+    if not (n_layer == 28 or (n_layer == 14 and index == 1024)):
+        invalid_flag = True
+
     if invalid_flag:
-        print((num_beams, input_tokens, max_new_tokens), "is not a supported combination " \
-              "[num_beams, input_tokens, max_new_tokens] for checking accuracy")
+        print((num_beams, input_tokens, max_new_tokens, n_layer), "is not a " \
+              "supported combination [num_beams, input_tokens, max_new_tokens, " \
+              "num_layer] for checking accuracy")
         exit(-1)

     gen_tokens = run_model(input_ids)
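Combined, the two guards above admit only a small grid of settings for which reference outputs exist. A quick enumeration sketch, pure Python, just spelling out the conditions from this hunk (the `num_beams`/greedy handling is left aside):

```python
# Enumerate the (input_tokens, max_new_tokens, num_layer) settings that
# survive both validity checks in this hunk.
accuracy_check_combination = [['32', 32], ['32', 1], ['1024', 128], ['1024', 1]]

valid = [
    (tok, new, n_layer)
    for tok, new in accuracy_check_combination
    for n_layer in (28, 14)
    if n_layer == 28 or (n_layer == 14 and int(tok) == 1024)
]
print(valid)
# [('32', 32, 28), ('32', 1, 28), ('1024', 128, 28), ('1024', 128, 14),
#  ('1024', 1, 28), ('1024', 1, 14)]
```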
@@ -151,9 +160,22 @@ def run_model(input_ids):
# "Over time, the concept evolved into this: you are sitting a table. You have your own plate, with is your 'base'. There are 5 other guests at the table, each with their own plate. Your plate can spawn little pieces of pasta. You do so by 'ordering' them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying 'costs', which are debited from your credits (you start with a number of credits). Once spawned, your pastas start flying around. Their instinct is to fly to other plates, in order to conquer them (the objective of the game is"
except_outputs = [1026, 318, 1760, 11, 290, 8948, 13, 921, 460, 711, 705, 34652, 2473, 286, 262, 309, 459, 6386, 6, 319, 5565, 11, 290, 319, 262, 3992, 13, 23911, 319, 262, 3992, 2499, 11, 475, 345, 423, 284, 29308, 3294, 3638, 329, 3084, 3867, 290, 326, 460, 307, 257, 1643, 15337, 13, 1318, 318, 257, 1256, 314, 1549, 588, 284, 1561, 546, 13, 314, 481, 467, 832, 790, 7243, 11, 916, 276, 286, 1642, 262, 7226, 644, 1816, 826, 14, 36460, 1351, 13, 26097, 14594, 625, 262, 7505, 373, 2192, 530, 286, 262, 17612, 8861, 543, 314, 550, 284, 1986, 13, 19486, 11, 314, 550, 281, 2126, 286, 644, 1611, 286, 983, 314, 2227, 284, 1205, 11, 11327, 10787, 532, 1223, 351, 257, 1256, 286, 5775, 14, 529, 669, 11, 2829, 9382, 11, 3863, 900, 287, 2272, 11, 6856, 422, 257, 1353, 12, 2902, 1570, 13, 314, 373, 6563, 326, 314, 714, 4197, 597, 7505, 1088, 340, 13, 554, 262, 886, 11, 262, 1917, 351, 257, 7505, 588, 705, 15200, 2122, 6, 287, 257, 983, 318, 326, 6954, 318, 555, 42191, 13, 632, 4325, 832, 1811, 9775, 4738, 23005, 625, 640, 11, 351, 262, 749, 15409, 9943, 7094, 16997, 13, 770, 8513, 1097, 35375, 318, 11, 287, 616, 4459, 11, 257, 1049, 1672, 286, 4036, 6954, 286, 257, 4693, 6476, 257, 4427, 13, 887, 318, 340, 257, 983, 30, 554, 257, 983, 11, 345, 761, 284, 1630, 1223, 284, 3151, 281, 9432, 13, 1320, 1630, 2925, 1028, 644, 6954, 318, 4385, 284, 307, 588, 13, 1002, 345, 1249, 262, 2836, 284, 2298, 703, 284, 18101, 1223, 11, 340, 338, 407, 6954, 7471, 532, 340, 338, 262, 7548, 286, 12661, 1486, 11, 262, 277, 540, 15646, 416, 6282, 1023, 284, 5249, 262, 2126, 286, 6954, 13, 11204, 556, 43758, 290, 257, 11303, 1878, 3699, 11, 326, 338, 407, 1223, 326, 31862, 502, 262, 826, 835, 13, 16227, 11, 616, 4094, 288, 8270, 2611, 618, 14615, 644, 284, 2251, 373, 407, 351, 644, 314, 2227, 284, 2251, 11, 475, 351, 644, 314, 750, 407, 13, 314, 1422, 470, 765, 284, 2251, 281, 705, 600, 32940, 1486, 6, 35375, 290, 31238, 869, 340, 6954, 13, 770, 318, 257, 1917, 11, 286, 1781, 11, 790, 584, 44047, 635, 550, 284, 1986, 13, 843, 22989, 416, 262, 12784, 8948, 11, 407, 867, 5257, 284, 670, 1088, 340, 13, 314, 1549, 910, 262, 691, 1103, 4610, 373, 832, 262, 779, 286, 11666, 6356, 11, 7599, 13, 1406, 1290, 11, 314, 4398, 470, 1775, 597, 5726, 1262, 428, 379, 663, 4755, 11327, 13, 45315, 11, 428, 318, 655, 257, 1257, 5449, 290, 706, 257, 981, 314, 3066, 407, 284, 307, 355, 7646, 351, 262, 983, 2126, 11, 290, 3142, 3589, 284, 2298, 4232, 314, 1807, 561, 670, 503, 13, 2011, 4238, 2126, 373, 284, 2251, 1223, 810, 9265, 3088, 284, 18101, 284, 257, 1306, 1241, 11, 475, 550, 617, 1611, 286, 22156, 2111, 284, 2245, 606, 422, 1804, 523, 13, 314, 1611, 286, 550, 428, 2939, 286, 1692, 15625, 7348, 287, 2272, 3371, 257, 937, 21446, 393, 257, 2272, 5156, 357, 439, 1912, 287, 5878, 25, 317, 4687, 28032, 286, 1781, 8, 475, 314, 3521, 470, 892, 286, 13206, 357, 961, 25, 2726, 8, 12933, 329, 326, 13, 29004, 82, 547, 616, 1306, 12141, 11, 355, 511, 2187, 14078, 4197, 2495, 880, 656, 262, 6954, 7505, 13, 887, 703, 284, 787, 340, 670, 30, 4231, 345, 262, 275, 2398, 11, 393, 4330, 262, 29004, 30, 383, 2368, 290, 2457, 2126, 1625, 284, 502, 832, 616, 11077, 11, 508, 7599, 2921, 502, 262, 2126, 286, 1642, 1223, 546, 262, 6954, 286, 11303, 64, 13, 383, 517, 314, 1807, 546, 340, 262, 517, 340, 14846, 588, 340, 561, 670, 11, 523, 314, 3066, 284, 467, 351, 340, 13, 32200, 602, 351, 616, 20886, 763, 12, 28816, 371, 516, 20342, 357, 8727, 635, 2727, 262, 705, 28452, 36684, 4698, 22242, 6, 9877, 11112, 329, 616, 493, 4951, 8, 2252, 48491, 262, 3721, 11, 355, 
340, 2950, 656, 262, 2126, 286, 1719, 1981, 5207, 286, 26296, 7348, 1088, 290, 2111, 284, 18101, 1566, 484, 2627, 477, 12, 44548, 13, 317, 9233, 2126, 994, 373, 326, 262, 983, 561, 670, 284, 4727, 703, 262, 19903, 1338, 35812, 12635, 1625, 284, 2152, 532, 416, 21568, 422, 257, 3487, 8073, 3084, 13, 1406, 262, 2126, 12572, 517, 393, 1342, 656, 428, 25, 345, 389, 5586, 257, 3084, 13, 921, 423, 534, 898, 7480, 11, 351, 318, 534, 705, 8692, 4458, 1318, 389, 642, 584, 10650, 379, 262, 3084, 11, 1123, 351, 511, 898, 7480, 13, 3406, 7480, 460, 10922, 1310, 5207, 286, 26296, 13, 921, 466, 523, 416, 705, 34555, 6, 606, 832, 257, 6859, 13, 2773, 1613, 292, 389, 1365, 621, 1854, 26, 617, 389, 5443, 11, 617, 389, 7387, 13, 1119, 423, 15874, 705, 15805, 82, 3256, 543, 389, 1915, 863, 422, 534, 10824, 357, 5832, 923, 351, 257, 1271, 286, 10824, 737, 4874, 29013, 11, 534, 1613, 292, 923, 7348, 1088, 13, 5334, 13311, 318, 284, 6129, 284, 584, 13375, 11, 287, 1502, 284, 23875, 606, 357, 1169, 9432, 286, 262, 983, 318, 1719, 534, 26296, 23875, 477, 262, 13375, 319, 262, 3084, 737, 887, 484, 389, 1107, 18284, 11, 523, 706, 852, 29013, 11, 345, 423, 645, 1630, 625, 534, 26296, 357, 14925, 22875, 32, 393, 6706, 43, 49100, 737, 3406, 26296, 1595, 470, 588, 584, 661, 338, 26296, 11, 523, 611, 484, 1826, 11, 484, 2686, 10746, 379, 1123, 584, 1566, 530, 10564, 13, 921, 651, 10824, 329, 584, 1613, 292, 534, 898, 26296, 1494, 13, 4874, 257, 26296, 318, 287, 262, 25980, 286, 257, 7480, 11, 340, 4940, 49977, 340, 329, 663, 1074, 13, 632, 2753, 1088, 838, 4201, 329, 257, 7480, 284, 307, 29346, 26, 1342, 611, 517, 26296, 422, 262, 976, 1074, 389, 1088, 13, 1002, 26296, 422, 584, 1074, 389, 1088, 11, 996, 11, 484, 651, 8970, 866, 287, 511, 2230, 11, 5906, 284, 23875, 262, 7480, 11, 1566, 530, 286, 606, 4656, 357, 14925, 30193, 338, 3210, 705, 3103, 6138, 6, 4235, 737, 921, 651, 2173, 790, 1218, 329, 790, 7480, 345, 898, 13, 3827, 640, 11, 262, 3721, 12572, 656, 428, 25, 345, 389, 5586, 257, 3084, 13, 921, 423, 534, 898, 7480, 11, 351, 318, 534, 705, 8692, 4458, 1318, 389, 642, 584, 10650, 379, 262, 3084, 11, 1123, 351, 511, 898, 7480, 13, 3406, 7480, 460, 10922, 1310, 5207, 286, 26296, 13, 921, 466, 523, 416, 705, 34555, 6, 606, 832, 257, 6859, 13, 2773, 1613, 292, 389, 1365, 621, 1854, 26, 617, 389, 5443, 11, 617, 389, 7387, 13, 1119, 423, 15874, 705, 15805, 82, 3256, 543, 389, 1915, 863, 422, 534, 10824, 357, 5832, 923, 351, 257, 1271, 286, 10824, 737, 4874, 29013, 11, 534, 1613, 292, 923, 7348, 1088, 13, 5334, 13311, 318, 284, 6129, 284, 584, 13375, 11, 287, 1502, 284, 23875, 606, 357, 1169, 9432, 286, 262, 983, 318]

+    if n_layer == 14:
+        except_outputs[index:] = [
+            935, 24043, 2196, 286, 428, 2126, 2627, 6481, 8253, 306, 47370, 306,
+            47370, 306, 986, 5832, 821, 8066, 651, 17533, 6949, 986, 22850, 11393,
+            30, 20525, 11, 356, 821, 8066, 1561, 546, 23514, 3307, 3307, 986,
+            5832, 821, 8066, 651, 17533, 6949, 986, 22850, 11393, 30, 6521, 11,
+            356, 821, 8066, 1561, 546, 23514, 3307, 986, 5832, 821, 8066, 651,
+            17533, 6949, 986, 22850, 11393, 30, 6521, 11, 356, 821, 8066, 1561,
+            546, 23514, 3307, 986, 5832, 821, 8066, 651, 17533, 6949, 986, 22850,
+            11393, 30, 6521, 11, 356, 821, 8066, 1561, 546, 23514, 3307, 986,
+            5832, 821, 8066, 651, 17533, 6949, 986, 22850, 11393, 30, 6521, 11,
+            356, 821, 8066, 1561, 546, 23514, 3307, 986, 5832, 821, 8066, 651,
+            17533, 6949, 986, 22850, 11393, 30, 9461, 11]
+
     if max_new_tokens == 1:
-        index = int(input_tokens)
         acc = (outputs[index] == except_outputs[index])
     else:
-        acc = (outputs == except_outputs).sum().item()/len(except_outputs)
+        acc = (outputs[index:] == except_outputs[index:]).sum().item()/max_new_tokens
     print("Accuracy = {:.2f}".format(acc))
