|
4 | 4 | // https://opensource.org/licenses/MIT
|
5 | 5 |
|
6 | 6 | /*
|
7 |
| -A General Feature Extractor class |
| 7 | +A class that extracts features from MobileNet |
8 | 8 | */
|
9 | 9 |
|
10 |
| -class FeatureExtractor extends Video { |
11 |
| - constructor(model, videoOrCallback, optionsOrCallback = {}, cb = () => {}) { |
12 |
| - super(video, IMAGESIZE); |
| 10 | +import * as tf from '@tensorflow/tfjs'; |
| 11 | +import Video from './../utils/Video'; |
| 12 | +import { IMAGENET_CLASSES } from './../utils/IMAGENET_CLASSES'; |
| 13 | +import { imgToTensor } from '../utils/imageUtilities'; |
| 14 | + |
// MobileNet expects square 224x224 input images.
const IMAGESIZE = 224;

// Fallback option values, used whenever the caller omits an option.
const DEFAULTS = {
  // MobileNet variant selection.
  version: 1,
  alpha: 1.0,
  // Number of top predictions to report.
  topk: 3,
  // Transfer-learning head hyperparameters.
  hiddenUnits: 100,
  numClasses: 2,
  // Training hyperparameters; batchSize is a fraction of the example count.
  learningRate: 0.0001,
  epochs: 20,
  batchSize: 0.4,
};
| 26 | + |
/**
 * A feature extractor built on MobileNet. The pretrained network's
 * 'conv_pw_13_relu' activation is used as a fixed feature embedding, and a
 * small trainable head is stacked on top for custom classification or
 * regression (transfer learning).
 */
class Mobilenet {
  /**
   * @param {Object} options - Optional overrides for DEFAULTS: epochs,
   *   hiddenUnits, numClasses, learningRate, batchSize (fraction of the
   *   stored examples used per batch).
   * @param {Function} callback - Invoked once MobileNet has finished loading.
   */
  constructor(options, callback) {
    this.mobilenet = null;
    this.modelPath = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
    this.topKPredictions = 10;
    this.modelLoaded = false;
    this.hasAnyTrainedClass = false;
    this.customModel = null;
    this.epochs = options.epochs || DEFAULTS.epochs;
    this.hiddenUnits = options.hiddenUnits || DEFAULTS.hiddenUnits;
    this.numClasses = options.numClasses || DEFAULTS.numClasses;
    this.learningRate = options.learningRate || DEFAULTS.learningRate;
    this.batchSize = options.batchSize || DEFAULTS.batchSize;
    this.isPredicting = false;
    this.mapStringToIndex = []; // maps string labels to numeric class indices
    this.usageType = null; // 'classifier' | 'regressor'; set by asClassifier/asRegressor

    this.loadModel()
      .then((net) => {
        this.modelLoaded = true;
        this.mobilenetFeatures = net;
        // Guard so a missing callback does not throw a TypeError inside the
        // promise chain (where it would surface as an unhandled rejection).
        if (typeof callback === 'function') {
          callback();
        }
      })
      // Surface load failures (bad network, bad URL) instead of leaving a
      // silent unhandled rejection from the constructor.
      .catch((err) => console.error('Error loading the MobileNet model:', err));
  }

  /**
   * Loads MobileNet and builds a truncated model whose output is the
   * 'conv_pw_13_relu' activation, used as the feature embedding.
   * @returns {Promise<tf.Model>} The truncated feature-extraction model.
   */
  async loadModel() {
    this.mobilenet = await tf.loadModel(this.modelPath);
    const layer = this.mobilenet.getLayer('conv_pw_13_relu');
    if (this.video) {
      // Warm up the model so the first real prediction is fast.
      tf.tidy(() => this.mobilenet.predict(imgToTensor(this.video)));
    }
    return tf.model({ inputs: this.mobilenet.inputs, outputs: layer.output });
  }

  /**
   * Configures this extractor as a classifier and attaches a video source.
   * @param {HTMLVideoElement|Object} video - Video element or p5.js wrapper.
   * @param {Function} callback - Invoked once the video is ready.
   * @returns {Mobilenet} this, for chaining.
   */
  asClassifier(video, callback) {
    this.usageType = 'classifier';
    return this.loadVideo(video, callback);
  }

  /**
   * Configures this extractor as a regressor and attaches a video source.
   * @param {HTMLVideoElement|Object} video - Video element or p5.js wrapper.
   * @param {Function} callback - Invoked once the video is ready.
   * @returns {Mobilenet} this, for chaining.
   */
  asRegressor(video, callback) {
    this.usageType = 'regressor';
    return this.loadVideo(video, callback);
  }

  /**
   * Wraps a video source (plain element or p5.js element) in the project's
   * Video helper and stores the ready video on this.video.
   * @param {HTMLVideoElement|Object} video - Video element or p5.js wrapper.
   * @param {Function} [callback] - Invoked once the video has loaded.
   * @returns {Mobilenet} this, for chaining.
   */
  loadVideo(video, callback = () => {}) {
    let inputVideo = null;

    if (video instanceof HTMLVideoElement) {
      inputVideo = video;
    } else if (typeof video === 'object' && video.elt instanceof HTMLVideoElement) {
      inputVideo = video.elt; // p5.js video element
    }

    if (inputVideo) {
      const vid = new Video(inputVideo, IMAGESIZE);
      vid.loadVideo().then(async () => {
        this.video = vid.video;
        callback();
      });
    }

    return this;
  }

  /**
   * Adds one training example. Accepted call shapes:
   *   addImage(img, label, cb) · addImage(img, cb) · addImage(label, cb)
   * In the label-first form the current video frame is used as the input.
   * String labels are interned into numeric indices via mapStringToIndex.
   * @param {HTMLImageElement|HTMLVideoElement|Object|string|number} inputOrLabel
   * @param {string|number|Function} [labelOrCallback]
   * @param {Function} [cb]
   */
  addImage(inputOrLabel, labelOrCallback, cb = () => {}) {
    let imgToAdd;
    let label;
    let callback = cb;

    if (inputOrLabel instanceof HTMLImageElement || inputOrLabel instanceof HTMLVideoElement) {
      imgToAdd = inputOrLabel;
    } else if (typeof inputOrLabel === 'object' && (inputOrLabel.elt instanceof HTMLImageElement || inputOrLabel.elt instanceof HTMLVideoElement)) {
      // Fix: unwrap the p5.js element so imgToTensor receives a DOM element,
      // consistent with classify() and predict().
      imgToAdd = inputOrLabel.elt;
    } else if (typeof inputOrLabel === 'string' || typeof inputOrLabel === 'number') {
      imgToAdd = this.video;
      label = inputOrLabel;
    }

    if (typeof labelOrCallback === 'string' || typeof labelOrCallback === 'number') {
      label = labelOrCallback;
    } else if (typeof labelOrCallback === 'function') {
      callback = labelOrCallback;
    }

    // Intern string labels as stable numeric indices.
    if (typeof label === 'string') {
      if (!this.mapStringToIndex.includes(label)) {
        label = this.mapStringToIndex.push(label) - 1;
      } else {
        label = this.mapStringToIndex.indexOf(label);
      }
    }

    if (this.modelLoaded) {
      tf.tidy(() => {
        const processedImg = imgToTensor(imgToAdd);
        const prediction = this.mobilenetFeatures.predict(processedImg);

        let y;
        if (this.usageType === 'classifier') {
          y = tf.tidy(() => tf.oneHot(tf.tensor1d([label], 'int32'), this.numClasses));
        } else if (this.usageType === 'regressor') {
          y = tf.tidy(() => tf.tensor2d([[label]]));
        }

        if (this.xs == null) {
          // First example: keep the tensors alive beyond this tidy scope.
          this.xs = tf.keep(prediction);
          this.ys = tf.keep(y);
          this.hasAnyTrainedClass = true;
        } else {
          // Append along the batch axis and release the superseded tensors.
          const oldX = this.xs;
          this.xs = tf.keep(oldX.concat(prediction, 0));
          const oldY = this.ys;
          this.ys = tf.keep(oldY.concat(y, 0));
          oldX.dispose();
          oldY.dispose();
          y.dispose();
        }
      });
      if (callback) {
        callback();
      }
    } else {
      console.warn('The model is not loaded yet.');
    }
  }

  /**
   * Trains the custom head on the collected examples.
   * @param {Function} onProgress - Called with the loss (string, 5 decimals)
   *   after each batch, and with null when training ends.
   * @returns {Promise} Resolves when training completes.
   * @throws {Error} If no examples were added, or the batch size rounds to 0.
   */
  async train(onProgress) {
    if (!this.hasAnyTrainedClass) {
      throw new Error('Add some examples before training!');
    }

    this.isPredicting = false;

    if (this.usageType === 'classifier') {
      this.loss = 'categoricalCrossentropy';
      this.customModel = tf.sequential({
        layers: [
          // MobileNet's conv_pw_13_relu activation is a [7, 7, 256] map.
          tf.layers.flatten({ inputShape: [7, 7, 256] }),
          tf.layers.dense({
            units: this.hiddenUnits,
            activation: 'relu',
            kernelInitializer: 'varianceScaling',
            useBias: true,
          }),
          tf.layers.dense({
            units: this.numClasses,
            kernelInitializer: 'varianceScaling',
            useBias: false,
            activation: 'softmax',
          }),
        ],
      });
    } else if (this.usageType === 'regressor') {
      this.loss = 'meanSquaredError';
      this.customModel = tf.sequential({
        layers: [
          tf.layers.flatten({ inputShape: [7, 7, 256] }),
          tf.layers.dense({
            units: this.hiddenUnits,
            activation: 'relu',
            kernelInitializer: 'varianceScaling',
            useBias: true,
          }),
          tf.layers.dense({
            units: 1,
            useBias: false,
            kernelInitializer: 'Zeros',
            activation: 'linear',
          }),
        ],
      });
    }

    const optimizer = tf.train.adam(this.learningRate);
    this.customModel.compile({ optimizer, loss: this.loss });
    // batchSize option is a fraction of the stored example count.
    const batchSize = Math.floor(this.xs.shape[0] * this.batchSize);
    if (!(batchSize > 0)) {
      throw new Error('Batch size is 0 or NaN. Please choose a non-zero fraction.');
    }

    // Fix: return the fit promise instead of leaving it floating, so callers
    // can await training completion and training errors propagate.
    return this.customModel.fit(this.xs, this.ys, {
      batchSize,
      epochs: this.epochs,
      callbacks: {
        onBatchEnd: async (batch, logs) => {
          onProgress(logs.loss.toFixed(5));
          await tf.nextFrame();
        },
        onTrainEnd: () => onProgress(null),
      },
    });
  }

  /* eslint max-len: ["error", { "code": 180 }] */
  /**
   * Classifies an image (or the current video frame) with the trained head.
   * @param {HTMLImageElement|HTMLVideoElement|Object|Function} inputOrCallback
   *   The input, or the callback when classifying the video.
   * @param {Function} [cb] - Receives the predicted label (string if string
   *   labels were used, otherwise the numeric class index).
   */
  async classify(inputOrCallback, cb = null) {
    if (this.usageType === 'classifier') {
      let imgToPredict;
      let callback;

      if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement) {
        imgToPredict = inputOrCallback;
      } else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement)) {
        imgToPredict = inputOrCallback.elt; // p5.js image element
      } else if (typeof inputOrCallback === 'function') {
        imgToPredict = this.video;
        callback = inputOrCallback;
      }

      if (typeof cb === 'function') {
        callback = cb;
      }

      this.isPredicting = true;
      const predictedClass = tf.tidy(() => {
        const processedImg = imgToTensor(imgToPredict);
        const activation = this.mobilenetFeatures.predict(processedImg);
        const predictions = this.customModel.predict(activation);
        return predictions.as1D().argMax();
      });
      let classId = (await predictedClass.data())[0];
      // Fix: release the kept tensor (predict() already did; classify leaked it).
      predictedClass.dispose();
      await tf.nextFrame();
      if (callback) {
        if (this.mapStringToIndex.length > 0) {
          classId = this.mapStringToIndex[classId];
        }
        callback(classId);
      }
    } else {
      console.warn('Mobilenet Feature Extraction has not been set to be a classifier.');
    }
  }

  /* eslint max-len: ["error", { "code": 180 }] */
  /**
   * Runs regression on an image (or the current video frame).
   * @param {HTMLImageElement|HTMLVideoElement|Object|Function} inputOrCallback
   *   The input, or the callback when predicting from the video.
   * @param {Function} [cb] - Receives the predicted numeric value.
   */
  async predict(inputOrCallback, cb = null) {
    if (this.usageType === 'regressor') {
      let imgToPredict;
      let callback;

      if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement) {
        imgToPredict = inputOrCallback;
      } else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement)) {
        imgToPredict = inputOrCallback.elt; // p5.js image element
      } else if (typeof inputOrCallback === 'function') {
        imgToPredict = this.video;
        callback = inputOrCallback;
      }

      if (typeof cb === 'function') {
        callback = cb;
      }

      this.isPredicting = true;
      const predictedClass = tf.tidy(() => {
        const processedImg = imgToTensor(imgToPredict);
        const activation = this.mobilenetFeatures.predict(processedImg);
        const predictions = this.customModel.predict(activation);
        return predictions.as1D();
      });
      const prediction = (await predictedClass.data());
      predictedClass.dispose();
      await tf.nextFrame();
      if (callback) {
        callback(prediction[0]);
      }
    } else {
      console.warn('Mobilenet Feature Extraction has not been set to be a regressor.');
    }
  }

  /**
   * Static method: resolves a logits tensor into the top-K ImageNet classes.
   * @param {tf.Tensor} logits - 1D tensor of class scores.
   * @param {number} topK - How many classes to return.
   * @param {Function} [callback] - Also receives the result array.
   * @returns {Promise<Array<{className: string, probability: number}>>}
   */
  static async getTopKClasses(logits, topK, callback) {
    const values = await logits.data();
    const valuesAndIndices = [];
    for (let i = 0; i < values.length; i += 1) {
      valuesAndIndices.push({ value: values[i], index: i });
    }
    valuesAndIndices.sort((a, b) => b.value - a.value);
    const topkValues = new Float32Array(topK);

    const topkIndices = new Int32Array(topK);
    for (let i = 0; i < topK; i += 1) {
      topkValues[i] = valuesAndIndices[i].value;
      topkIndices[i] = valuesAndIndices[i].index;
    }
    const topClassesAndProbs = [];
    for (let i = 0; i < topkIndices.length; i += 1) {
      topClassesAndProbs.push({
        className: IMAGENET_CLASSES[topkIndices[i]],
        probability: topkValues[i],
      });
    }

    await tf.nextFrame();

    if (callback) {
      callback(topClassesAndProbs);
    }
    return topClassesAndProbs;
  }
}
|
15 | 326 |
|
16 |
| -export default FeatureExtractor; |
| 327 | +export default Mobilenet; |
0 commit comments