
Commit 3788b18

simple phi3 chat example (#424)
* add phi3 ort-web example
* fix ort package version
* pin ort, limit width of user messages
* switch to webpack
* update readme to reflect webpack
* fix wasm path
* add olive instructions to readme
* new naming convention and location for the model
* future proof build config
1 parent 2a76152 commit 3788b18

8 files changed: +752 −0 lines changed

js/README.md

Lines changed: 3 additions & 0 deletions
@@ -49,3 +49,6 @@ Click links for README of each examples.

* [Facebook Segment-Anything](segment-anything) - demonstrates how to run [segment-anything](https://github.com/facebookresearch/segment-anything) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.

* [Stable Diffusion Turbo](sd-turbo) - demonstrates how to run [Stable Diffusion Turbo](https://huggingface.co/stabilityai/sd-turbo) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.

* [Phi-3-mini-4k-instruct](chat) - demonstrates how to run [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu. *(added)*

js/chat/README.md

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Local Chat using Phi3, ONNX Runtime Web and WebGPU

This repository contains an example of running [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) with WebGPU.

You can try out the live demo [here](https://guschmue.github.io/ort-webgpu/chat/index.html).

We keep this example simple and use the onnxruntime-web API directly, without a higher-level framework like [transformers.js](https://github.com/xenova/transformers.js); a minimal sketch of that pattern follows.
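
The sketch below is not part of the commit; it only illustrates, under assumed names (`model.onnx`, an `input_ids` input), the bare onnxruntime-web pattern that `llm.js` in this commit builds on: create an `InferenceSession`, build `ort.Tensor` inputs, and call `run`.

```js
import * as ort from 'onnxruntime-web/webgpu';

// Illustrative only: 'model.onnx' and the input name/shape are placeholders,
// not the Phi3 model from this example.
async function runOnce() {
  const session = await ort.InferenceSession.create('model.onnx', {
    executionProviders: ['webgpu'],
  });
  // a 1x4 int64 tensor, e.g. a small batch of token ids
  const input = new ort.Tensor('int64', BigInt64Array.from([1n, 2n, 3n, 4n]), [1, 4]);
  const results = await session.run({ input_ids: input });
  console.log(results);
}

runOnce();
```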
## Getting Started

### Prerequisites

Ensure that you have [Node.js](https://nodejs.org/) installed on your machine.

### Installation

Install the required dependencies:

```sh
npm install
```

### Building the project

Build the project:

```sh
npm run build
```

The output can be found in the ***dist*** directory.
### Building for development

```sh
npm run dev
```

This will build the project and start a dev server.
Point your browser to http://localhost:8080/.
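
The commit message mentions switching to webpack, but the build config itself is not shown in this diff view. Purely as a sketch of what such a config might look like (file name, entry point, and plugin options here are assumptions, not the committed file), one common setup copies the onnxruntime-web wasm binaries into ***dist*** so the `ort.env.wasm.wasmPaths` setting in `llm.js` can find them:

```js
// webpack.config.js - hypothetical sketch, not the file from this commit.
const path = require('path');
const CopyPlugin = require('copy-webpack-plugin');

module.exports = {
  entry: './main.js',
  output: {
    filename: 'main.js',
    path: path.resolve(__dirname, 'dist'),
  },
  plugins: [
    // copy the onnxruntime-web wasm binaries next to the bundle so that
    // ort.env.wasm.wasmPaths (pointed at dist/ in llm.js) can resolve them
    new CopyPlugin({
      patterns: [
        { from: 'node_modules/onnxruntime-web/dist/*.wasm', to: '[name][ext]' },
      ],
    }),
  ],
  mode: 'production',
};
```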
### The Phi3 ONNX Model

The model used in this example is hosted on [Hugging Face](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx-web). It is slightly different from the ONNX model for CUDA or CPU:

1. The model output 'logits' is kept as float32 (even for float16 models) because JavaScript does not support float16.
2. Our WebGPU implementation uses the custom Multi-Head Attention operator instead of Group Query Attention.
3. Phi3 is larger than 2GB, so we need to use external data files. To keep them cacheable in the browser, both model.onnx and model.onnx.data are kept under 2GB (see the sketch after this list for how the external data file is passed to the session).
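
As a small illustration of point 3: onnxruntime-web accepts the external weights buffer through the `externalData` session option. This mirrors what `llm.js` below does, with `fetchBytes` standing in for the commit's `fetchAndCache` helper:

```js
import * as ort from 'onnxruntime-web/webgpu';

// Sketch of loading a model that has an external data file.
const fetchBytes = async (url) => (await fetch(url)).arrayBuffer();

const modelBytes = await fetchBytes('model.onnx');
const externalData = await fetchBytes('model.onnx.data');

const session = await ort.InferenceSession.create(modelBytes, {
  executionProviders: ['webgpu'],
  externalData: [
    // path must match the external-data location recorded inside model.onnx
    { data: externalData, path: 'model.onnx.data' },
  ],
});
```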
The model was created using the [ONNX genai model builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models).

If you would like to create the model yourself, you can use [Olive](https://github.com/microsoft/Olive/). An example of how to create the model for ONNX Runtime Web with Olive can be found [here](https://github.com/microsoft/Olive/tree/main/examples/phi3).

js/chat/index.html

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
<!doctype html>
<html lang="en">

<head>
  <meta charset="UTF-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
    integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous" />
  <link rel="stylesheet" href="main.css">

  <title>Chat with onnxruntime-web</title>
</head>

<body data-bs-theme="dark">
  <div id="root"></div>

  <div class="container">
    <div class="row pt-3">
      <div class="col-md-8 col-12">
        <h2>Chat with onnxruntime-web</h2>
      </div>
      <div id="status">
      </div>
    </div>
    <div id="scroll-wrapper">
      <div id="chat-container" class="card">
        <div class="card-body">
          <div id="chat-history"></div>
        </div>
      </div>
    </div>
  </div>
  <div class="container p-0 card" id="input-area">
    <div class="input-group">
      <textarea class="form-control" id="user-input" placeholder="Type your question here ..."></textarea>
      <button id="send-button" class="btn btn-primary">Send</button>
    </div>
  </div>

  <script type="module" src="dist/main.js"></script>
</body>

</html>

js/chat/llm.js

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
import * as ort from 'onnxruntime-web/webgpu';

ort.env.wasm.numThreads = 1;
ort.env.wasm.simd = true;
ort.env.wasm.wasmPaths = document.location.pathname.replace('index.html', '') + 'dist/';


function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; }

//
// load file from server or cache
//
async function fetchAndCache(url) {
    try {
        const cache = await caches.open("onnx");
        let cachedResponse = await cache.match(url);
        if (cachedResponse === undefined) {
            log(`${url} (network)`);
            const buffer = await fetch(url).then(response => response.arrayBuffer());
            try {
                await cache.put(url, new Response(buffer));
            } catch (error) {
                console.error(error);
            }
            return buffer;
        }
        log(`${url} (cached)`);
        const data = await cachedResponse.arrayBuffer();
        return data;
    } catch (error) {
        log(`can't fetch ${url}`);
        throw error;
    }
}

//
// class to handle a large language model on top of onnxruntime-web
//
export class LLM {
    sess = undefined;
    profiler = false;
    feed = {};
    output_tokens = [];
    eos = 2;
    need_position_ids = true;
    stop = false;
    kv_dims = [];
    dtype = "float16";
    max_tokens = 9999;

    constructor() {
    }

    async load(model, options) {
        const provider = options.provider || "webgpu";
        const verbose = options.verbose;
        const local = options.local;
        const hasFP16 = (provider === "wasm") ? false : options.hasFP16;
        this.profiler = options.profiler;

        const model_path = (local) ? "models/" + model.path : "https://huggingface.co/" + model.path + "/resolve/main";
        let model_file = model.file || "model";
        model_file = (hasFP16) ? model_file + "_q4f16.onnx" : model_file + "_q4.onnx";

        log(`loading... ${model.name}, ${provider}`);
        const json_bytes = await fetchAndCache(model_path + "/config.json");
        let textDecoder = new TextDecoder();
        const model_config = JSON.parse(textDecoder.decode(json_bytes));

        const model_bytes = await fetchAndCache(model_path + "/onnx/" + model_file);
        const externaldata = (model.externaldata) ? await fetchAndCache(model_path + "/onnx/" + model_file + '_data') : false;
        let modelSize = model_bytes.byteLength;
        if (externaldata) {
            modelSize += externaldata.byteLength;
        }
        log(`model size ${Math.round(modelSize / 1024 / 1024)} MB`);

        const opt = {
            executionProviders: [provider],
            preferredOutputLocation: {},
        };

        switch (provider) {
            case "webgpu":
                // keep the kv-cache outputs on the gpu to avoid round trips to the cpu
                for (let i = 0; i < model_config.num_hidden_layers; ++i) {
                    opt.preferredOutputLocation[`present.${i}.key`] = 'gpu-buffer';
                    opt.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer';
                }
                break;
        }

        if (externaldata) {
            opt.externalData = [
                {
                    data: externaldata,
                    path: model_file + "_data",
                },
            ];
        }
        if (verbose) {
            opt.logSeverityLevel = 0;
            opt.logVerbosityLevel = 0;
            ort.env.logLevel = "verbose";
        }

        ort.env.webgpu.profiling = {};
        if (this.profiler) {
            opt.enableProfiling = true;
            ort.env.webgpu.profilingMode = 'default';
            ort.env.webgpu.profiling.mode = 'default';
        }

        this.sess = await ort.InferenceSession.create(model_bytes, opt);
        this.eos = model_config.eos_token_id;
        this.kv_dims = [1, model_config.num_key_value_heads, 0, model_config.hidden_size / model_config.num_attention_heads];
        this.dtype = (hasFP16) ? "float16" : "float32";
        this.num_layers = model_config.num_hidden_layers;
        this.initialize_feed();
    }

    initialize_feed() {
        const feed = this.feed;

        // dispose of previous gpu buffers
        for (const name in feed) {
            const t = feed[name];
            if (t.location === 'gpu-buffer') {
                t.dispose();
            }
        }
        this.feed = {};
        // key value cache is zero copy, just pass the gpu buffer as reference
        const empty = (this.dtype === "float16") ? new Uint16Array() : [];
        for (let i = 0; i < this.num_layers; ++i) {
            this.feed[`past_key_values.${i}.key`] = new ort.Tensor(this.dtype, empty, this.kv_dims);
            this.feed[`past_key_values.${i}.value`] = new ort.Tensor(this.dtype, empty, this.kv_dims);
        }
        this.output_tokens = [];
    }

    //
    // poor man's argmax over the last token's logits
    //
    argmax(t) {
        const arr = t.data;
        const start = t.dims[2] * (t.dims[1] - 1);
        let max = arr[start];
        let maxidx = 0;

        for (let i = 0; i < t.dims[2]; i++) {
            const val = arr[i + start];
            if (!isFinite(val)) {
                throw new Error("found non-finite value in logits");
            }
            if (val > max) {
                max = arr[i + start];
                maxidx = i;
            }
        }
        return maxidx;
    }

    //
    // update key value cache
    //
    update_kv_cache(feed, outputs) {
        for (const name in outputs) {
            if (name.startsWith('present')) {
                let newName = name.replace('present', 'past_key_values');
                // dispose previous gpu buffers
                const t = feed[newName];
                if (t.location === 'gpu-buffer') {
                    t.dispose();
                }
                feed[newName] = outputs[name];
            }
        }
    }

    //
    // tell generate() to stop
    //
    abort() {
        this.stop = true;
    }

    //
    // prefill prompt and generate tokens, greedy search only
    //
    async generate(tokens, callback, options) {
        const max_tokens = options.max_tokens || 256;
        const feed = this.feed;
        const input_ids = new ort.Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, tokens.length]);
        feed['input_ids'] = input_ids;
        this.stop = false;

        this.output_tokens.push(...input_ids.data);

        let last_token = 0n;
        let seqlen = this.output_tokens.length;
        const input_len = input_ids.size;

        if (this.need_position_ids) {
            feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from({ length: input_len }, (_, i) => BigInt(seqlen - input_len + i)), [1, input_len]);
        }

        // 32007 is the <|end|> token for phi3
        while (last_token != this.eos && last_token != 32007 && seqlen < max_tokens && !this.stop) {
            seqlen = this.output_tokens.length;
            feed['attention_mask'] = new ort.Tensor('int64', BigInt64Array.from({ length: seqlen }, () => 1n), [1, seqlen]);
            const outputs = await this.sess.run(feed);
            last_token = BigInt(this.argmax(outputs.logits));
            this.output_tokens.push(last_token);
            if (callback && !this.profiler) {
                callback(this.output_tokens);
            }
            this.update_kv_cache(feed, outputs);
            feed['input_ids'] = new ort.Tensor('int64', BigInt64Array.from([last_token]), [1, 1]);
            if (this.need_position_ids) {
                feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from([BigInt(seqlen)]), [1, 1]);
            }
        }
        if (this.profiler) {
            this.sess.endProfiling();
        }
        return this.output_tokens;
    }
}
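
The entry point `dist/main.js` referenced by `index.html` is part of this commit but not shown in this section. A hypothetical driver for the `LLM` class above might look like the following; the model descriptor fields mirror what `load()` reads and the Hugging Face repo named in the README, while the prompt token ids are placeholders that would normally come from a real tokenizer:

```js
import { LLM } from './llm.js';

// Hypothetical usage sketch, not the commit's main.js.
const llm = new LLM();

await llm.load(
  // path points at the repo the README says hosts the model
  { name: "phi3", path: "microsoft/Phi-3-mini-4k-instruct-onnx-web", externaldata: true },
  { provider: "webgpu", hasFP16: true, verbose: false, profiler: false, local: false },
);

// placeholder token ids; a real app would run the prompt through a tokenizer
const promptTokens = [1, 32010, 22172, 32007, 32001];
const outputTokens = await llm.generate(promptTokens, (tokens) => {
  console.log(`generated ${tokens.length} tokens so far`);
}, { max_tokens: 256 });
```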

js/chat/main.css

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
body {
  color: #f5f5f5;
  font-family: 'Arial', sans-serif;
}

.user-message {
  background-color: rgb(86, 144, 163);
  color: white;
  padding: 10px;
  border-radius: 10px;
  white-space: pre-wrap;
  width: fit-content;
}

.response-message {
  background-color: rgb(62, 62, 62);
  color: white;
  padding: 10px;
  border-radius: 10px;
  padding-right: 20px;
  position: relative;
  margin-right: auto;
}

.response-message p {
  margin-right: 40px;
}

#chat-container {
  display: none;
  margin: 0 auto;
  overflow: auto;
}

#chat-history {
  display: flex;
  flex-direction: column;
}

.copy-button {
  position: absolute;
  bottom: 5px;
  right: 5px;
  margin: 0 5px 5px 0;
}

#scroll-wrapper {
  padding-bottom: 5.5rem;
}

#input-area {
  position: fixed;
  bottom: 0;
  margin-bottom: 5px;
  left: 50%;
  transform: translateX(-50%);
}
