Skip to content

Commit 049078a

Browse files
committed
LanguageModel: Require explicitly specifying the model name used
Besides providing a model name, the user can also pass an object containing the URL of a custom model. In both cases, the user is explicit about which model they are exploring. As suggested by @shiffman.
1 parent ed0ab03 commit 049078a

File tree

6 files changed

+30
-21
lines changed

6 files changed

+30
-21
lines changed

examples/LanguageModel/sketch.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function setup() {
77
createCanvas(400, 400);
88
background(0);
99

10-
lm = ml5.languageModel(onModelLoaded);
10+
lm = ml5.languageModel('TinyStories-15M', onModelLoaded);
1111
}
1212

1313
function draw() {

examples/LanguageModelAsync/sketch.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ let lm;
33
async function setup() {
44
noCanvas();
55

6-
lm = await ml5.languageModel();
6+
lm = await ml5.languageModel('TinyStories-15M');
77
console.log('Model loaded');
88

99
select('#generate').mouseClicked(generateText);

examples/LanguageModelEvent/sketch.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function setup() {
77
createCanvas(400, 400);
88
background(0);
99

10-
lm = ml5.languageModel(onModelLoaded);
10+
lm = ml5.languageModel('TinyStories-15M', onModelLoaded);
1111
}
1212

1313
function draw() {

examples/LanguageModelManual/sketch.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ let numOptions = 40;
44
async function setup() {
55
noCanvas();
66

7-
lm = await ml5.languageModel(onModelLoaded);
7+
lm = await ml5.languageModel('TinyStories-15M', onModelLoaded);
88
}
99

1010
function draw() {

examples/LanguageModelManualAsync/sketch.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ let numOptions = 40;
44
async function setup() {
55
noCanvas();
66

7-
lm = await ml5.languageModel();
7+
lm = await ml5.languageModel('TinyStories-15M');
88
console.log('Model loaded');
99

1010
select('#generate').mouseClicked(generateText);

src/LanguageModel/index.js

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,47 @@
55

66
import { EventEmitter } from "events";
77
import callCallback from "../utils/callcallback";
8-
import handleArguments from "../utils/handleArguments";
98

109
import Llama2 from './llama2.js';
1110
import Llama2Wasm from './llama2.wasm';
1211
import Llama2Data from './llama2.data';
1312

1413

1514
class LanguageModel extends EventEmitter {
16-
constructor(optionsOrCb, cb) {
15+
constructor(modelNameOrOptions, callback) {
1716
super();
1817

1918
this.options = {
20-
modelUrl: 'https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin', // if set, model.bin will be preloaded from provided URL (assumed to be embedded in llama2.data if not)
19+
modelUrl: '', // if set, model.bin will be preloaded from provided URL (assumed to be embedded in llama2.data if not)
2120
tokenizerUrl: '', // if set, tokenizer.bin will be preloaded from provided URL (assumed to be embedded in llama2.data if not)
2221
steps: 0, // how many tokens to generate (defaults to model's maximum)
2322
temperature: 0.9, // 0.0 = (deterministic) argmax sampling, 1.0 = baseline
2423
stopOnBosOrEos: true, // stop when encountering beginning-of-sequence or end-of-sequence token
2524
};
2625

2726
// handle arguments
28-
let callback;
29-
if (typeof optionsOrCb === 'function') {
30-
callback = optionsOrCb;
31-
} else {
32-
if (typeof optionsOrCb === 'object') {
33-
this.options.modelUrl = (typeof optionsOrCb.modelUrl === 'string') ? optionsOrCb.modelUrl : this.options.modelUrl;
34-
this.options.tokenizerUrl = (typeof optionsOrCb.tokenizerUrl === 'string') ? optionsOrCb.tokenizerUrl : this.options.tokenizerUrl;
35-
}
36-
if (typeof cb === 'function') {
37-
callback = cb;
27+
if (typeof modelNameOrOptions === 'string') {
28+
switch (modelNameOrOptions) {
29+
// see https://huggingface.co/karpathy/tinyllamas for TinyStories-*
30+
case 'TinyStories-15M':
31+
this.options.modelUrl = 'https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin';
32+
break;
33+
case 'TinyStories-42M':
34+
this.options.modelUrl = 'https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin';
35+
break;
36+
case 'TinyStories-110M':
37+
this.options.modelUrl = 'https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin';
38+
break;
39+
default:
40+
throw 'Unrecognized model ' + modelNameOrOptions + ', try e.g. TinyStories-15M';
3841
}
42+
} else if (typeof modelNameOrOptions === 'object') {
43+
this.options.modelUrl = (typeof modelNameOrOptions.modelUrl === 'string') ? modelNameOrOptions.modelUrl : this.options.modelUrl;
44+
this.options.tokenizerUrl = (typeof modelNameOrOptions.tokenizerUrl === 'string') ? modelNameOrOptions.tokenizerUrl : this.options.tokenizerUrl;
45+
}
46+
47+
if (!this.options.modelUrl) {
48+
throw 'You need to provide the name of the model to load, e.g. TinyStories-15M';
3949
}
4050

4151
this.out = '';
@@ -279,9 +289,8 @@ class LanguageModel extends EventEmitter {
279289
* exposes LanguageModel class through function
280290
* @returns {Object|Promise<Boolean>} A new LanguageModel instance
281291
*/
282-
const languageModel = (...inputs) => {
283-
const { options = {}, callback } = handleArguments(...inputs);
284-
const instance = new LanguageModel(options, callback);
292+
const languageModel = (modelNameOrOptions, callback) => {
293+
const instance = new LanguageModel(modelNameOrOptions, callback);
285294
return instance;
286295
};
287296

0 commit comments

Comments
 (0)