Skip to content

Commit 697d1c8

Browse files
authored
[In Progress] Offline support for ml5 (#553)
* added modelLoader util * added proof of concept for offline sentiment * adds offline support for uNet * updates sketchRnn for offline support * adds offline support for yolo * adds modelUrl to config * adds package lock * fix imageSize ref * adds support for external URL defs * updates packagelock * fixes unet tests
1 parent d107aab commit 697d1c8

File tree

8 files changed

+127
-49
lines changed

8 files changed

+127
-49
lines changed

package-lock.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/FeatureExtractor/Mobilenet.js

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,17 @@ class Mobilenet {
7474
learningRate: options.learningRate || DEFAULTS.learningRate,
7575
batchSize: options.batchSize || DEFAULTS.batchSize,
7676
layer: options.layer || DEFAULTS.layer,
77-
alpha: options.alpha || DEFAULTS.alpha
77+
alpha: options.alpha || DEFAULTS.alpha,
7878
}
79+
80+
// for graph model
81+
this.model = null;
82+
this.url = MODEL_INFO[this.config.version][this.config.alpha];
83+
this.normalizationOffset = tf.scalar(127.5);
84+
85+
// check if a mobilenet URL is given
86+
this.mobilenetURL = options.mobilenetURL || `${BASE_URL}${this.config.version}_${this.config.alpha}_${IMAGE_SIZE}/model.json`;
87+
this.graphModelURL = options.graphModelURL || this.url;
7988
/**
8089
* Boolean value to check if the model is predicting.
8190
* @public
@@ -92,15 +101,17 @@ class Mobilenet {
92101
this.usageType = null;
93102
this.ready = callCallback(this.loadModel(), callback);
94103

95-
// for graph model
96-
this.model = null;
97-
this.url = MODEL_INFO[this.config.version][this.config.alpha];
98-
this.normalizationOffset = tf.scalar(127.5);
104+
99105
}
100106

101107
async loadModel() {
102-
this.mobilenet = await tf.loadLayersModel(`${BASE_URL}${this.config.version}_${this.config.alpha}_${IMAGE_SIZE}/model.json`);
103-
this.model = await tf.loadGraphModel(this.url, {fromTFHub: true});
108+
this.mobilenet = await tf.loadLayersModel(this.mobilenetURL);
109+
if(this.graphModelURL.includes('https://tfhub.dev/')){
110+
this.model = await tf.loadGraphModel(this.graphModelURL, {fromTFHub: true});
111+
} else {
112+
this.model = await tf.loadGraphModel(this.graphModelURL, {fromTFHub: false});
113+
}
114+
104115

105116

106117
const layer = this.mobilenet.getLayer(this.config.layer);

src/PoseNet/index.js

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class PoseNet extends EventEmitter {
6161
* @type {String}
6262
* @public
6363
*/
64+
this.modelUrl = options.modelUrl || null;
6465
this.architecture = options.architecture || DEFAULTS.architecture;
6566
this.detectionType = detectionType || options.detectionType || DEFAULTS.detectionType;
6667
this.imageScaleFactor = options.imageScaleFactor || DEFAULTS.imageScaleFactor;
@@ -85,7 +86,8 @@ class PoseNet extends EventEmitter {
8586
outputStride: this.outputStride,
8687
inputResolution: this.inputResolution,
8788
multiplier: this.multiplier,
88-
quantBytes: this.quantBytes
89+
quantBytes: this.quantBytes,
90+
modelUrl: this.modelUrl
8991
}
9092
} else {
9193
modelJson = {

src/Sentiment/index.js

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import * as tf from '@tensorflow/tfjs';
22
import callCallback from '../utils/callcallback';
3+
import modelLoader from '../utils/modelLoader';
4+
35
/**
4-
* Initializes the Sentiment demo.
5-
*/
6+
* Initializes the Sentiment demo.
7+
*/
68

79
const OOV_CHAR = 2;
810
const PAD_CHAR = 0;
@@ -51,19 +53,32 @@ class Sentiment {
5153
}
5254

5355
/**
54-
* Initializes the Sentiment demo.
55-
*/
56+
* Initializes the Sentiment demo.
57+
*/
5658

5759
async loadModel(modelName) {
5860

61+
const movieReviews = {
62+
model: null,
63+
metadata: null,
64+
}
65+
5966
if (modelName.toLowerCase() === 'moviereviews') {
67+
68+
movieReviews.model = 'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/model.json';
69+
movieReviews.metadata = 'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/metadata.json';
70+
71+
} else if(modelLoader.isAbsoluteURL(modelName) === true ) {
72+
const modelPath = modelLoader.getModelPath(modelName);
73+
74+
movieReviews.model = `${modelPath}/model.json`;
75+
movieReviews.metadata = `${modelPath}/metadata.json`;
76+
77+
} else {
78+
console.error('problem loading model');
79+
return this;
80+
}
6081

61-
const movieReviews = {
62-
model:
63-
'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/model.json',
64-
metadata:
65-
'https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/metadata.json',
66-
};
6782

6883
/**
6984
* The model being used.
@@ -80,9 +95,6 @@ class Sentiment {
8095
this.wordIndex = sentimentMetadata.word_index;
8196
this.vocabularySize = sentimentMetadata.vocabulary_size;
8297

83-
} else {
84-
console.error('problem loading model')
85-
}
8698
return this;
8799
}
88100

@@ -112,15 +124,12 @@ class Sentiment {
112124
const score = predictOut.dataSync()[0];
113125
predictOut.dispose();
114126

115-
return { score };
127+
return {
128+
score
129+
};
116130
}
117131
}
118132

119-
const sentiment = (modelName, callback) => new Sentiment( modelName, callback ) ;
120-
121-
export default sentiment;
122-
123-
124-
125-
133+
const sentiment = (modelName, callback) => new Sentiment(modelName, callback);
126134

135+
export default sentiment;

src/SketchRNN/index.js

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,21 @@ SketchRNN
1212
import * as ms from '@magenta/sketch';
1313
import callCallback from '../utils/callcallback';
1414
import modelPaths from './models';
15+
import modelLoader from '../utils/modelLoader';
1516

16-
const PATH_START_LARGE = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/';
17-
const PATH_START_SMALL = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/models/';
18-
const PATH_END = '.gen.json';
17+
// const PATH_START_LARGE = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/';
18+
// const PATH_START_SMALL = 'https://storage.googleapis.com/quickdraw-models/sketchRNN/models/';
19+
// const PATH_END = '.gen.json';
20+
21+
22+
const DEFAULTS = {
23+
modelPath: 'https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/',
24+
modelPath_large: 'https://storage.googleapis.com/quickdraw-models/sketchRNN/models/',
25+
modelPath_small: 'https://storage.googleapis.com/quickdraw-models/sketchRNN/models/',
26+
PATH_END: '.gen.json',
27+
temperature: 0.65,
28+
pixelFactor: 3.0,
29+
}
1930

2031
class SketchRNN {
2132
/**
@@ -27,21 +38,37 @@ class SketchRNN {
2738
*/
2839
constructor(model, callback, large = true) {
2940
let checkpointUrl = model;
30-
if (modelPaths.has(checkpointUrl)) {
31-
checkpointUrl = (large ? PATH_START_LARGE : PATH_START_SMALL) + checkpointUrl + PATH_END;
32-
}
33-
this.defaults = {
41+
42+
this.config = {
3443
temperature: 0.65,
3544
pixelFactor: 3.0,
45+
modelPath: DEFAULTS.modelPath,
46+
modelPath_small: DEFAULTS.modelPath_small,
47+
modelPath_large: DEFAULTS.modelPath_large,
48+
PATH_END: DEFAULTS.PATH_END,
3649
};
37-
this.model = new ms.SketchRNN(checkpointUrl);
50+
51+
52+
if(modelLoader.isAbsoluteURL(checkpointUrl) === true){
53+
const modelPath = modelLoader.getModelPath(checkpointUrl);
54+
this.config.modelPath = modelPath;
55+
56+
} else if(modelPaths.has(checkpointUrl)) {
57+
checkpointUrl = (large ? this.config.modelPath : this.config.modelPath_small) + checkpointUrl + this.config.PATH_END;
58+
this.config.modelPath = checkpointUrl;
59+
} else {
60+
console.log('no model found!');
61+
return this;
62+
}
63+
64+
this.model = new ms.SketchRNN(this.config.modelPath);
3865
this.penState = this.model.zeroInput();
3966
this.ready = callCallback(this.model.initialize(), callback);
4067
}
4168

4269
async generateInternal(options, strokes) {
43-
const temperature = +options.temperature || this.defaults.temperature;
44-
const pixelFactor = +options.pixelFactor || this.defaults.pixelFactor;
70+
const temperature = +options.temperature || this.config.temperature;
71+
const pixelFactor = +options.pixelFactor || this.config.pixelFactor;
4572

4673
await this.ready;
4774
if (!this.rnnState) {

src/UNET/index.js

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ import callCallback from '../utils/callcallback';
1212
import { array3DToImage } from '../utils/imageUtilities';
1313
import p5Utils from '../utils/p5Utils';
1414

15-
const URL = 'https://raw.githubusercontent.com/zaidalyafeai/HostedModels/master/unet-128/model.json';
16-
const imageSize = 128;
15+
const DEFAULTS = {
16+
modelPath: 'https://raw.githubusercontent.com/zaidalyafeai/HostedModels/master/unet-128/model.json',
17+
imageSize: 128
18+
}
1719

1820
class UNET {
1921
/**
@@ -26,11 +28,15 @@ class UNET {
2628
constructor(video, options, callback) {
2729
this.modelReady = false;
2830
this.isPredicting = false;
31+
this.config = {
32+
modelPath: typeof options.modelPath !== 'undefined' ? options.modelPath : DEFAULTS.modelPath,
33+
imageSize: typeof options.imageSize !== 'undefined' ? options.imageSize : DEFAULTS.imageSize
34+
};
2935
this.ready = callCallback(this.loadModel(), callback);
3036
}
3137

3238
async loadModel() {
33-
this.model = await tf.loadLayersModel(URL);
39+
this.model = await tf.loadLayersModel(this.config.modelPath);
3440
this.modelReady = true;
3541
return this;
3642
}
@@ -80,7 +86,7 @@ class UNET {
8086
const tensor = tf.tidy(() => {
8187
// preprocess
8288
const tfImage = tf.browser.fromPixels(imgToPredict).toFloat();
83-
const resizedImg = tf.image.resizeBilinear(tfImage, [imageSize, imageSize]);
89+
const resizedImg = tf.image.resizeBilinear(tfImage, [this.config.imageSize, this.config.imageSize]);
8490
const normTensor = resizedImg.div(tf.scalar(255));
8591

8692
const batchedImage = normTensor.expandDims(0);
@@ -102,7 +108,7 @@ class UNET {
102108
let image;
103109

104110
if (p5Utils.checkP5()) {
105-
const blob1 = await p5Utils.rawToBlob(raw, imageSize, imageSize);
111+
const blob1 = await p5Utils.rawToBlob(raw, this.config.imageSize, this.config.imageSize);
106112
const p5Image1 = await p5Utils.blobToP5Image(blob1);
107113
image = p5Image1;
108114
}
@@ -129,7 +135,7 @@ const uNet = (videoOr, optionsOr, cb) => {
129135
callback = videoOr;
130136
} else if (typeof videoOr === 'object') {
131137
options = videoOr;
132-
}
138+
}
133139

134140
if (typeof optionsOr === 'object') {
135141
options = optionsOr;

src/YOLO/index.js

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import Video from '../utils/Video';
1414
import { imgToTensor } from '../utils/imageUtilities';
1515
import callCallback from '../utils/callcallback';
1616
import CLASS_NAMES from './../utils/COCO_CLASSES';
17+
import modelLoader from '../utils/modelLoader';
1718

1819
import {
1920
nonMaxSuppression,
@@ -23,9 +24,8 @@ import {
2324
ANCHORS,
2425
} from './postprocess';
2526

26-
const URL = 'https://raw.githubusercontent.com/ml5js/ml5-data-and-training/master/models/YOLO/model.json';
27-
2827
const DEFAULTS = {
28+
modelUrl: 'https://raw.githubusercontent.com/ml5js/ml5-data-and-training/master/models/YOLO/model.json',
2929
filterBoxesThreshold: 0.01,
3030
IOUThreshold: 0.4,
3131
classProbThreshold: 0.4,
@@ -51,6 +51,7 @@ class YOLOBase extends Video {
5151
constructor(video, options, callback) {
5252
super(video, imageSize);
5353

54+
this.modelUrl = options.modelUrl || DEFAULTS.modelUrl;
5455
this.filterBoxesThreshold = options.filterBoxesThreshold || DEFAULTS.filterBoxesThreshold;
5556
this.IOUThreshold = options.IOUThreshold || DEFAULTS.IOUThreshold;
5657
this.classProbThreshold = options.classProbThreshold || DEFAULTS.classProbThreshold;
@@ -64,7 +65,15 @@ class YOLOBase extends Video {
6465
if (this.videoElt && !this.video) {
6566
this.video = await this.loadVideo();
6667
}
67-
this.model = await tf.loadLayersModel(URL);
68+
69+
if(modelLoader.isAbsoluteURL(this.modelUrl) === true){
70+
this.model = await tf.loadLayersModel(this.modelUrl);
71+
} else {
72+
const modelPath = modelLoader.getModelPath(this.modelUrl);
73+
this.modelUrl = `${modelPath}/model.json`;
74+
this.model = await tf.loadLayersModel(this.modelUrl);
75+
}
76+
6877
this.modelReady = true;
6978
return this;
7079
}

src/utils/modelLoader.js

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
function isAbsoluteURL(str) {
2+
const pattern = new RegExp('^(?:[a-z]+:)?//', 'i');
3+
return !!pattern.test(str);
4+
}
5+
6+
function getModelPath(absoluteOrRelativeUrl) {
7+
const modelJsonPath = isAbsoluteURL(absoluteOrRelativeUrl) ? absoluteOrRelativeUrl : window.location.pathname + absoluteOrRelativeUrl
8+
return modelJsonPath;
9+
}
10+
11+
export default {
12+
isAbsoluteURL,
13+
getModelPath
14+
}

0 commit comments

Comments
 (0)