Skip to content

Commit 08b2616

Browse files
authored
Merge pull request #917 from epfml/NAN-hellaswag-gpt-evaluation-christinakopi
Hellaswag GPT evaluation
2 parents 7f32856 + 419645b commit 08b2616

File tree

14 files changed

+587
-0
lines changed

14 files changed

+587
-0
lines changed

cli/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,26 @@ CLI options can be listed with `npm -w cli run benchmark_gpt -- -h`.
5252
To benchmark model training, you can run `npm -w cli run benchmark_gpt -- --modelType gpt-nano --contextLength 128 --batchSize 8`.
5353

5454
For inference run `npm -w cli run benchmark_gpt -- --inference --modelPath <path to trained model json file>`. You can use the `docs/example/wikitext` example script to train a model. The model needs to be trained on the wikitext default task to ensure that model parameters such as vocab size, tokenizer, max sequence length are the same between training and inference.
55+
56+
## Evaluating GPT Models on HellaSwag
57+
58+
The CLI includes a script to evaluate GPT models on the [HellaSwag](https://rowanzellers.com/hellaswag/) dataset, a common benchmark for evaluating commonsense reasoning in language models.
59+
60+
To run the evaluation: `npm -w cli run hellaswag_gpt`
61+
62+
The script benchmarks the following models:
63+
- A TensorFlow.js implementation of GPT (`gpt-tfjs`)
64+
- A pre-trained ONNX model (`Xenova/gpt2`)
65+
66+
Both models are evaluated using a shared tokenizer (`Xenova/gpt2`), and the script reports:
67+
- Accuracy (proportion of correct multiple-choice predictions)
68+
- Total evaluation time (in seconds)
69+
70+
### Output
71+
72+
Results are printed to the console and saved to a log file: `../datasets/logFile_hellaswag.txt`
73+
74+
75+
This allows for a direct comparison between the inference performance and accuracy of the two architectures.
76+
77+
The TFJS implementation is generally slower and more memory-intensive than ONNX, but offers compatibility with browser-based environments and custom training workflows. See the [Benchmarking GPT-TF.js](#benchmarking-gpt-tfjs) section for more details on performance tradeoffs.

cli/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"start": "npm run build && node dist/cli.js",
99
"benchmark_gpt": "npm run build && node dist/benchmark_gpt.js",
1010
"train_gpt": "npm run build && node dist/train_gpt.js",
11+
"hellaswag_gpt": "npm run build && node dist/hellaswag_gpt.js",
1112
"build": "tsc",
1213
"lint": "npx eslint .",
1314
"test": ": nothing"

cli/src/hellaswag_gpt.ts

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import '@tensorflow/tfjs-node';
2+
import { loadHellaSwag } from '@epfml/discojs-node';
3+
import { models } from '@epfml/discojs';
4+
import { AutoTokenizer, PreTrainedTokenizer } from '@xenova/transformers';
5+
import fs from 'fs';
6+
import path from 'node:path';
7+
8+
const logFile = path.join('..', 'datasets', 'LogFile_hellaswag.txt');
9+
const logLines: string[] = [];
10+
11+
function log(message: string) {
12+
console.log(message);
13+
logLines.push(message);
14+
}
15+
16+
const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(-1)
17+
18+
async function evaluateTFJS(tokenizer: PreTrainedTokenizer) {
19+
const model = new models.GPT({ seed: 42 });
20+
log('Evaluating TFJS GPT on HellaSwag...');
21+
22+
const start = Date.now();
23+
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
24+
const duration = ((Date.now() - start) / 1000).toFixed(2);
25+
26+
log(`TFJS GPT Accuracy: ${(accuracy * 100).toFixed(2)}%`);
27+
log(`TFJS GPT Evaluation Time: ${duration} seconds`);
28+
}
29+
30+
async function evaluateXenova(tokenizer: PreTrainedTokenizer) {
31+
const model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
32+
log('Evaluating Xenova GPT-2 (ONNX) on HellaSwag...');
33+
34+
const start = Date.now();
35+
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
36+
const duration = ((Date.now() - start) / 1000).toFixed(2);
37+
38+
log(`Xenova GPT-2 Accuracy: ${(accuracy * 100).toFixed(2)}%`);
39+
log(`Xenova GPT-2 Evaluation Time: ${duration} seconds`);
40+
}
41+
42+
async function main(): Promise<void> {
43+
fs.writeFileSync(logFile, '', 'utf-8'); // Clear old log file
44+
45+
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2');
46+
await evaluateTFJS(tokenizer);
47+
log('\n---\n');
48+
await evaluateXenova(tokenizer);
49+
50+
fs.writeFileSync(logFile, logLines.join('\n'), 'utf-8');
51+
console.log(`\nResults written to ${logFile}`);
52+
}
53+
54+
main().catch(console.error);

discojs-node/src/hellaswag.spec.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import { expect } from 'chai';
2+
import { load as loadHellaSwag } from './hellaswag.js';
3+
4+
describe('HellaSwag parser', () => {
5+
it('should load all examples and return them as an array', async () => {
6+
const dataset = await loadHellaSwag(10);
7+
8+
expect(dataset).to.be.an('array');
9+
expect(dataset.length).to.be.greaterThan(0);
10+
11+
// Check the structure of the first example
12+
const example = dataset[0];
13+
expect(example).to.have.property('ctx').that.is.a('string');
14+
expect(example).to.have.property('endings').that.is.an('array').with.lengthOf(4);
15+
expect(example).to.have.property('label').that.is.a('number');
16+
});
17+
});
18+

discojs-node/src/hellaswag.ts

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import { models } from '@epfml/discojs';
2+
import fetch from 'node-fetch';
3+
4+
/**
5+
* Loads the HellaSwag dataset from the remote URL in Node.js
6+
*
7+
* @param limit - Maximum number of examples to load (-1 means all)
8+
* @returns A HellaSwagDataset containing the examples.
9+
*/
10+
export async function load(limit = -1): Promise<models.HellaSwagDataset> {
11+
const response = await fetch(models.HELLASWAG_URL);
12+
if (!response.ok) {
13+
throw new Error(`Failed to fetch dataset from ${models.HELLASWAG_URL}: ${response.statusText}`);
14+
}
15+
16+
const text = await response.text();
17+
const lines = text.split('\n');
18+
19+
const dataset: models.HellaSwagDataset = [];
20+
let count = 0;
21+
for (const line of lines) {
22+
if (line.trim().length === 0) continue;
23+
if (limit !== -1 && count >= limit) break;
24+
25+
try {
26+
const data = JSON.parse(line.trim()) as models.HellaSwagExample;
27+
dataset.push(data);
28+
count++;
29+
} catch (e) {
30+
console.error(`Failed to parse line:`, line);
31+
throw e;
32+
}
33+
}
34+
35+
return dataset;
36+
}

discojs-node/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
export * from './loaders/index.js'
22
export { saveModelToDisk, loadModelFromDisk } from './model_loader.js'
3+
export { load as loadHellaSwag } from './hellaswag.js'

discojs-web/src/hellaswag.spec.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import { describe, it, expect } from "vitest";
2+
import { load as loadHellaSwag } from './hellaswag.js';
3+
import { models } from '@epfml/discojs';
4+
5+
describe('hellaswag parser', () => {
6+
it('loads the whole hellaswag dataset', async () => {
7+
const dataset: models.HellaSwagDataset = await loadHellaSwag(2);
8+
9+
// basic assertions
10+
expect(dataset).to.be.an('array');
11+
expect(dataset.length).to.equal(2);
12+
13+
// check structure of the first example
14+
const first = dataset[0];
15+
expect(first).to.have.property('ctx').that.is.a('string');
16+
expect(first).to.have.property('endings').that.is.an('array').with.lengthOf(4);
17+
expect(first).to.have.property('label').that.is.a('number');
18+
});
19+
});

discojs-web/src/hellaswag.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import { models } from '@epfml/discojs';
2+
3+
/**
4+
* Loads the HellaSwag dataset from the remote URL in the browser
5+
*
6+
* @param limit - Maximum number of examples to load (-1 means all)
7+
* @returns A HellaSwagDataset containing the examples
8+
*/
9+
export async function load(limit = -1): Promise<models.HellaSwagDataset> {
10+
const response = await fetch(models.HELLASWAG_URL);
11+
if (!response.ok) {
12+
throw new Error(`Failed to fetch dataset from ${models.HELLASWAG_URL}: ${response.statusText}`);
13+
}
14+
15+
const text = await response.text();
16+
const lines = text.split('\n');
17+
18+
const dataset: models.HellaSwagDataset = [];
19+
let count = 0;
20+
for (const line of lines) {
21+
if (line.trim().length === 0) continue;
22+
if (limit !== -1 && count >= limit) break;
23+
24+
try {
25+
const data = JSON.parse(line.trim()) as models.HellaSwagExample;
26+
dataset.push(data);
27+
count++;
28+
} catch (e) {
29+
console.error(`Failed to parse line:`, line);
30+
throw e;
31+
}
32+
}
33+
34+
return dataset;
35+
}

discojs-web/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
export * from "./loaders/index.js";
2+
export { load as loadHellaSwag } from "./hellaswag.js";
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import { expect } from 'chai';
2+
import { evaluate } from './hellaswag.js';
3+
import { PreTrainedTokenizer } from '@xenova/transformers';
4+
import { GPT } from './index.js';
5+
import { ONNXModel } from './onnx.js';
6+
import type { HellaSwagExample } from './hellaswag.js';
7+
8+
export const exampleDataset: HellaSwagExample[] = [
9+
{
10+
ctx: "A man is sitting on a roof. he",
11+
endings: [
12+
"is using wrap to wrap a pair of skis.",
13+
"is ripping level tiles off.",
14+
"is holding a rubik's cube.",
15+
"starts pulling up roofing on a roof."
16+
],
17+
label: 3
18+
},
19+
{
20+
ctx: "A lady walks to a barbell. She bends down and grabs the pole. the lady",
21+
endings: [
22+
"swings and lands in her arms.",
23+
"pulls the barbell forward.",
24+
"pulls a rope attached to the barbell.",
25+
"stands and lifts the weight over her head."
26+
],
27+
label: 3
28+
}
29+
];
30+
31+
describe('HellaSwag Evaluator', () => {
32+
it('evaluates tfjs GPT model', async () => {
33+
const tokenizer = await PreTrainedTokenizer.from_pretrained('Xenova/gpt2');
34+
const gpt = new GPT({seed: 42,}); // seed for reproducibility
35+
36+
const accuracy = await evaluate(gpt, tokenizer, exampleDataset, true);
37+
expect(accuracy).to.be.gte(0);
38+
expect(accuracy).to.be.lte(1);
39+
}).timeout(6000);
40+
});
41+
42+
describe('HellaSwag Evaluator with Xenova GPT-2', () => {
43+
it('evaluates the pretrained GPT-2 model', async () => {
44+
const tokenizer = await PreTrainedTokenizer.from_pretrained('Xenova/gpt2');
45+
const model = await ONNXModel.init_pretrained('Xenova/gpt2');
46+
47+
const accuracy = await evaluate(model, tokenizer, exampleDataset, true);
48+
expect(accuracy).to.be.gte(0);
49+
expect(accuracy).to.be.lte(1);
50+
}).timeout(10000);
51+
});
52+
53+
describe('Deterministic evaluation with tfjs GPT-2', () => {
54+
it('returns the same accuracy across runs', async () => {
55+
const tokenizer = await PreTrainedTokenizer.from_pretrained('Xenova/gpt2');
56+
const gpt = new GPT({seed: 42,});
57+
58+
const accuracy1 = await evaluate(gpt, tokenizer, exampleDataset, false);
59+
const accuracy2 = await evaluate(gpt, tokenizer, exampleDataset, false);
60+
61+
expect(accuracy1).to.equal(accuracy2);
62+
}).timeout(10000);
63+
});

0 commit comments

Comments
 (0)