Skip to content

Commit 7d7e20a

Browse files
committed
onnx-converter: new npm workspace to convert GPT2 from ONNX to TFJS
1 parent 55538d7 commit 7d7e20a

File tree

17 files changed

+10410
-43
lines changed

17 files changed

+10410
-43
lines changed

cli/src/hellaswag_gpt.ts

Lines changed: 79 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,104 @@
1+
// import fs from 'fs';
2+
import fsPromise from 'node:fs/promises';
3+
4+
import { dirname } from 'path';
5+
import { fileURLToPath } from 'url';
6+
import { parse } from 'ts-command-line-args'
7+
18
import '@tensorflow/tfjs-node';
29
import fs from 'node:fs';
310
import path from 'node:path';
4-
import { Tokenizer, models } from '@epfml/discojs';
11+
import { models, serialization, Tokenizer } from '@epfml/discojs';
512
import { loadHellaSwag } from '@epfml/discojs-node';
13+
// import { AutoTokenizer } from '@xenova/transformers';
614

7-
const logFile = path.join('..', 'datasets', 'LogFile_hellaswag.txt');
8-
const logLines: string[] = [];
15+
const __dirname = dirname(fileURLToPath(import.meta.url));
916

17+
// Buffers every logged line so the whole run can be dumped to a file at the end.
const logLines: string[] = [];

/** Echoes `message` to stdout and records it for the final log-file dump. */
function log(message: string) {
  logLines.push(message);
  console.log(message);
}
1422

15-
const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(-1)
16-
17-
async function evaluateTFJS(tokenizer: Tokenizer) {
18-
const model = new models.GPT({ seed: 42 });
19-
log('Evaluating TFJS GPT on HellaSwag...');
23+
async function evaluateModel(model: models.GPT | models.ONNXModel, numDataPoints = -1) {
24+
const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(numDataPoints)
25+
const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
26+
log('Starting the HellaSwag benchmark...');
2027

2128
const start = Date.now();
22-
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
29+
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, true);
2330
const duration = ((Date.now() - start) / 1000).toFixed(2);
2431

25-
log(`TFJS GPT Accuracy: ${(accuracy * 100).toFixed(2)}%`);
26-
log(`TFJS GPT Evaluation Time: ${duration} seconds`);
32+
log(`Final accuracy: ${(accuracy * 100).toFixed(2)}%`);
33+
log(`Evaluation Time: ${duration} seconds`);
2734
}
2835

29-
async function evaluateXenova(tokenizer: Tokenizer) {
30-
const model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
31-
log('Evaluating Xenova GPT-2 (ONNX) on HellaSwag...');
36+
const ModelTypes = ['onnx', 'gpt-tfjs-random', 'gpt-tfjs-pretrained'] as const;
37+
type ModelType = typeof ModelTypes[number];
3238

33-
const start = Date.now();
34-
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
35-
const duration = ((Date.now() - start) / 1000).toFixed(2);
36-
37-
log(`Xenova GPT-2 Accuracy: ${(accuracy * 100).toFixed(2)}%`);
38-
log(`Xenova GPT-2 Evaluation Time: ${duration} seconds`);
39+
interface HellaSwagArgs {
40+
model: ModelType
41+
numDataPoints: number
42+
logFile: string
43+
pretrainedModelPath: string
44+
help?: boolean
3945
}
4046

4147
async function main(): Promise<void> {
42-
fs.writeFileSync(logFile, '', 'utf-8'); // Clear old log file
48+
const defaultPretrainedModelPath = path.join(__dirname, "..", "..", "onnx-converter", "assets", "model.json")
49+
const args = parse<HellaSwagArgs>({
50+
model: {
51+
type: (raw: string) => raw as ModelType,
52+
description: `Model type, one of ${ModelTypes}`,
53+
defaultValue: 'onnx'
54+
},
55+
numDataPoints: {
56+
type: Number,
57+
description: 'Number of HellaSwag datapoints to evaluate, set -1 for the whole benchmark',
58+
defaultValue: -1
59+
},
60+
logFile: {
61+
type: String,
62+
description: 'Relative path to the log file, default to ./hellaswag.log', defaultValue: 'hellaswag.log'
63+
},
64+
pretrainedModelPath: {
65+
type: String,
66+
description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model',
67+
defaultValue: defaultPretrainedModelPath
68+
},
69+
help: {
70+
type: Boolean,
71+
optional: true,
72+
alias: 'h',
73+
description: 'Prints this usage guide'
74+
}
75+
}, { helpArg: 'help' })
4376

44-
const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
45-
await evaluateTFJS(tokenizer);
46-
log('\n---\n');
47-
await evaluateXenova(tokenizer);
77+
const logFile = path.join(__dirname, args.logFile);
78+
fs.writeFileSync(logFile, '', 'utf-8'); // Clear the log file
79+
80+
let model: | models.GPT | models.ONNXModel | undefined;
81+
switch (args.model) {
82+
case 'onnx':
83+
log("Using ONNX pretrained model Xenova/gpt2")
84+
model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
85+
break;
86+
case 'gpt-tfjs-random':
87+
log("Using GPT-TFJS with random initialization")
88+
model = new models.GPT({ seed: 42 });
89+
break;
90+
case 'gpt-tfjs-pretrained':
91+
log("Using GPT-TFJS with pretrained weights")
92+
if (args.pretrainedModelPath === undefined) {
93+
throw new Error("If choosing gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model `pretrainedModelPath")
94+
}
95+
const encodedModel = await fsPromise.readFile(args.pretrainedModelPath);
96+
model = await serialization.model.decode(encodedModel) as models.GPT;
97+
break;
98+
default:
99+
throw new Error(`Unrecognized model type: ${model}`);
100+
}
101+
await evaluateModel(model, args.numDataPoints);
48102

49103
fs.writeFileSync(logFile, logLines.join('\n'), 'utf-8');
50104
console.log(`\nResults written to ${logFile}`);

datasets/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@
2020

2121
# GDHF demo
2222
/tinder_dog/
23+
24+
# HellaSwag benchmark
25+
hellaswag*

discojs/src/models/gpt/layers.spec.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ describe('GPT Layers', () => {
174174
name: 'testCSA',
175175
contextLength: 5,
176176
nHead: 2,
177-
nEmbd: 8, // divisible by nHead, so head size = 4
178-
dropout: 0.0, // no dropout for deterministic tests
177+
nEmbd: 8, // divisible by nHead, so head size = 4
178+
attnDrop: 0.0, // no dropout for deterministic tests
179+
residDrop: 0.0,
179180
nLayer: 2,
180181
seed: 42
181182
};

discojs/src/models/hellaswag.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ type ModelType = GPT | ONNXModel;
126126
export async function evaluate(
127127
model: ModelType,
128128
tokenizer: Tokenizer,
129-
dataset: HellaSwagExample[],
129+
dataset: HellaSwagDataset,
130130
print = true
131131
): Promise<number> {
132132
let correct = 0;

eslint.config.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ export default defineConfigWithVueTs(
6565
},
6666
{ ignores: ["**/dist/*"] },
6767
{ ignores: ["docs/examples/**"] },
68+
{ ignores: ["**/src/protobuf/"] },
69+
6870
// don't use linter for formatting
6971
skipFormatting,
7072
);

onnx-converter/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
node_modules
2+
assets
3+
dist

onnx-converter/README.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
## Usage
2+
3+
This workspace is currently used to convert the ONNX [GPT-2 model](https://huggingface.co/Xenova/gpt2) to Tensorflow.js. On the one hand, ONNX allows converting pretrained models from PyTorch or Tensorflow to the ONNX format, therefore there currently exist many pretrained models in ONNX format. However, ONNX libraries currently only support inference. On the other hand, Tensorflow.js doesn't have a converter that can handle recent Transformers models (despite having a [converter](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter)), but TF.js allows further training of models.
4+
5+
Therefore, we want to convert pretrained models such as GPT-2 from ONNX format to Tensorflow.js to further fine-tune them. You generate a TF.js `model.json` by running `npm run convert_onnx` in this workspace.
6+
7+
What the script does is:
8+
1. Read the ONNX GPT-2 model from [Xenova's repository](https://huggingface.co/Xenova/gpt2)
9+
2. Use the ONNX protobuf definition to read the file and iterate through the model layers. The ONNX JavaScript protobuf comes from [this repository](https://github.com/microsoft/onnxruntime/blob/main/js/web/lib/onnxjs/).
10+
3. Convert all weights to TF.js tensors
11+
4. Init a TF.js model with the loaded weights and export the model
12+
13+
Running `npm run convert_onnx` creates a GPT-tfjs `model.json` file in the `./assets/` folder.
14+
15+
## ONNX JS protobuf
16+
17+
The ONNX specification has limited support in JavaScript. We found an old JS implementation in the [ONNX Runtime Web repository](https://github.com/microsoft/onnxruntime/tree/main/js/web/lib/onnxjs/ort-schema/protobuf). We had to adapt their files as follows to be compatible with our newer environment:
18+
1. Copy `onnx.js` and `onnx.d.ts` from [the repository](https://github.com/microsoft/onnxruntime/tree/main/js/web/lib/onnxjs/ort-schema/protobuf) in `./onnx-converter/src/protobuf`
19+
2. Rename `onnx.js` to `onnx.cjs`
20+
3. Create `onnx-proto.js` as a wrapper around the protobuf definition:
21+
```js
22+
import { createRequire } from 'module';
23+
const require = createRequire(import.meta.url);
24+
const onnxModule = require('./onnx.cjs');
25+
26+
export const onnx = onnxModule.onnx;
27+
export default onnxModule;
28+
```
29+
4. Create `onnx-proto.d.ts` with the matching TypeScript definition:
30+
```ts
31+
export { onnx } from './onnx.js';
32+
declare const onnxModule: {
33+
onnx: typeof import('./onnx.js').onnx;
34+
};
35+
export default onnxModule;
36+
```

onnx-converter/package.json

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"name": "onnx-converter",
3+
"private": true,
4+
"type": "module",
5+
"main": "dist/gpt2_from_onnx.js",
6+
"scripts": {
7+
"convert_onnx": "npm run build && node dist/convert_onnx.js",
8+
"build": "tsc && cp -r src/protobuf dist",
9+
"lint": "npx eslint .",
10+
"test": ": nothing"
11+
},
12+
"author": "",
13+
"license": "ISC",
14+
"dependencies": {
15+
"@epfml/discojs-node": "*"
16+
},
17+
"devDependencies": {
18+
"nodemon": "3",
19+
"ts-command-line-args": "2"
20+
}
21+
}

onnx-converter/src/convert_onnx.ts

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import { onnx } from './protobuf/onnx-proto.js';
2+
import { Map, Range } from 'immutable';
3+
import { fileURLToPath } from 'url';
4+
import { dirname } from 'path';
5+
import path from 'node:path';
6+
import fsPromise from 'node:fs/promises';
7+
import * as tf from '@tensorflow/tfjs-node';
8+
9+
import { models, serialization } from "@epfml/discojs";
10+
11+
const __dirname = dirname(fileURLToPath(import.meta.url));
12+
const ASSET_FOLDER = path.join(__dirname, "..", "assets");
13+
const OUTPUT_FILENAME = path.join(ASSET_FOLDER, "model.json");
14+
15+
const GPT2_N_LAYER = 12;
16+
17+
18+
async function main() {
19+
const ONNX_URL = "https://huggingface.co/Xenova/gpt2/resolve/main/onnx/decoder_model.onnx?download=true"
20+
console.log(`Downloading ONNX model from ${ONNX_URL}...`);
21+
const response = await fetch(ONNX_URL);
22+
if (!response.ok)
23+
throw new Error(`Failed to fetch ONNX model from ${ONNX_URL}: ${response.statusText}`);
24+
const arrayBuffer = await response.arrayBuffer();
25+
const data = new Uint8Array(arrayBuffer);
26+
27+
console.log(`Download complete (${(data.length / 1024 / 1024).toFixed(2)} MB).`);
28+
console.log(`Decoding protobuf...`);
29+
30+
const onnxModel = onnx.ModelProto.decode(data)
31+
32+
if (!onnxModel.graph || !onnxModel.graph.initializer)
33+
throw new Error("No graph or tensors found in the ONNX model.");
34+
console.log('ONNX model loaded successfully');
35+
36+
37+
// Init empty TF.js model
38+
// Context length value from https://huggingface.co/Xenova/gpt2/blob/main/config.json
39+
const gptModel = new models.GPT({ modelType: 'gpt2', contextLength: 1024 });
40+
if (gptModel.config.nLayer != GPT2_N_LAYER)
41+
throw new Error(`ONNX conversion only supports GPT-2 with 12 layers, instead found ${gptModel.config.nLayer}.`);
42+
const gptLayersModel = gptModel.extract();
43+
44+
console.log("Converting ONNX tensors to TF.js tensors")
45+
// Layer name mapping between ONNX and TF.js
46+
const onnxTfjsMapping = createWeightNameMap();
47+
// Create a mapping between layer name and TF.js weight tensors
48+
let preTrainedWeights = Map<string, tf.Tensor>(); // layer name to weight tensor
49+
for (const tensor of onnxModel.graph.initializer) {
50+
if (tensor.name === undefined || tensor.name === null)
51+
throw new Error("Undefined layer named")
52+
53+
const tfjsName = onnxTfjsMapping.get(tensor.name);
54+
if (tfjsName === undefined)
55+
throw new Error(`Missing ONNX weight in layer mapping: ${tensor.name}`);
56+
if (preTrainedWeights.get(tfjsName))
57+
throw new Error(`Duplicate weight name found: ${tfjsName}`);
58+
59+
if (tensor.dims === undefined || tensor.dims === null)
60+
throw new Error(`Undefined layer dimensions for ${tensor.name}`)
61+
const dims = tensor.dims.map((d) => Number(d));
62+
const flatData = parseTensorData(tensor);
63+
const tfTensor = tf.tensor(flatData).reshape(dims)
64+
preTrainedWeights = preTrainedWeights.set(tfjsName, tfTensor);
65+
}
66+
67+
console.log("Initializing a new TFJS GPT-2 model...")
68+
if (preTrainedWeights.size !== onnxTfjsMapping.size)
69+
throw new Error(`Expected to load ${onnxTfjsMapping.size} weights, but loaded ${preTrainedWeights.size}.`);
70+
71+
// Overwrite the GPT-TF.js model weights with the ONNX weights
72+
if (gptLayersModel.weights.length !== onnxTfjsMapping.size)
73+
throw new Error(`Mismatch between TFJS and ONNX weight mapping weights.`);
74+
75+
const finalWeights = gptLayersModel.weights.map((weight, _i) => {
76+
const newTensor = preTrainedWeights.get(weight.name);
77+
if (newTensor === undefined)
78+
throw new Error(`Missing ${weight.name} in the ONNX weight`);
79+
return newTensor;
80+
});
81+
82+
gptLayersModel.setWeights(finalWeights); // shape or transpose mismatch will throw here
83+
84+
const encoded = await serialization.model.encode(gptModel)
85+
await fsPromise.mkdir(ASSET_FOLDER, { recursive: true})
86+
await fsPromise.writeFile(OUTPUT_FILENAME, encoded)
87+
console.log(`GPT-TFJS model saved to ${OUTPUT_FILENAME}`)
88+
}
89+
90+
/**
91+
*
92+
* @param tensor
93+
* @returns
94+
*/
95+
function parseTensorData(tensor: onnx.ITensorProto): Float32Array {
96+
// Check for raw data (common in larger models)
97+
if (tensor.rawData && tensor.rawData.length > 0) {
98+
const buffer = tensor.rawData.buffer.slice(
99+
tensor.rawData.byteOffset,
100+
tensor.rawData.byteOffset + tensor.rawData.byteLength
101+
);
102+
if (tensor.dataType != onnx.TensorProto.DataType.FLOAT) {
103+
throw new Error("found protobuf data type different from expected float 32.")
104+
}
105+
return new Float32Array(buffer);
106+
}
107+
// Fallback to specific field arrays if rawData is empty
108+
console.log("WARNING: protobuf raw data is empty, falling back on specific data fields.")
109+
if (tensor.floatData && tensor.floatData.length > 0) return new Float32Array(tensor.floatData);
110+
111+
throw new Error("protobuf raw data and float data are empty.")
112+
}
113+
114+
/**
115+
* Maps ONNX weight names to TFJS weight names.
116+
* This mapping is specific to GPT-2 137M with 12 layers.
117+
* @param prefix the TFJS model name specified in its GPTConfig, default is 'transformer'
118+
*/
119+
function createWeightNameMap(): Map<string, string> {
120+
let map = Map<string, string>();
121+
122+
map = map.set(`transformer.wte.weight`, `transformer/wte/embedding`);
123+
map = map.set(`transformer.wpe.weight`, `transformer/wpe/embeddings`);
124+
125+
Range(0, GPT2_N_LAYER).forEach(i => {
126+
const onnxPrefix = `transformer.h.${i}`;
127+
const tfjsPrefix = `transformer/h${i}`;
128+
map = map.set(`${onnxPrefix}.ln_1.weight`, `${tfjsPrefix}/ln_1/gamma`);
129+
map = map.set(`${onnxPrefix}.ln_1.bias`, `${tfjsPrefix}/ln_1/beta`);
130+
map = map.set(`${onnxPrefix}.attn.c_attn.weight`, `${tfjsPrefix}/attn/c_attn/kernel`);
131+
map = map.set(`${onnxPrefix}.attn.c_attn.bias`, `${tfjsPrefix}/attn/c_attn/bias`);
132+
map = map.set(`${onnxPrefix}.attn.c_proj.weight`, `${tfjsPrefix}/attn/c_proj/kernel`);
133+
map = map.set(`${onnxPrefix}.attn.c_proj.bias`, `${tfjsPrefix}/attn/c_proj/bias`);
134+
map = map.set(`${onnxPrefix}.ln_2.weight`, `${tfjsPrefix}/ln_2/gamma`);
135+
map = map.set(`${onnxPrefix}.ln_2.bias`, `${tfjsPrefix}/ln_2/beta`);
136+
map = map.set(`${onnxPrefix}.mlp.c_fc.weight`, `${tfjsPrefix}/mlp/c_fc/kernel`);
137+
map = map.set(`${onnxPrefix}.mlp.c_fc.bias`, `${tfjsPrefix}/mlp/c_fc/bias`);
138+
map = map.set(`${onnxPrefix}.mlp.c_proj.weight`, `${tfjsPrefix}/mlp/c_proj/kernel`);
139+
map = map.set(`${onnxPrefix}.mlp.c_proj.bias`, `${tfjsPrefix}/mlp/c_proj/bias`);
140+
});
141+
142+
map = map.set(`transformer.ln_f.weight`, `transformer/ln_f/gamma`);
143+
map = map.set(`transformer.ln_f.bias`, `transformer/ln_f/beta`);
144+
return map;
145+
}
146+
147+
148+
// Entry point: log any conversion failure instead of an unhandled rejection.
await main().catch(console.error);

0 commit comments

Comments
 (0)