Skip to content

Commit f75bc83

Browse files
yining1023joeyklee
authored andcommitted
Update ImageClassifier and FeatureExtractor to return [{label, confidence}] (#292)
* change imageClassifier.predict() to classify(), keep predict until later version * update imageclassifier.classify results to [{label: .., confidence: ..}] * update featureExtractor regressor return {value: 0-1} * update featureextractor.classify to return [{label, confidence}]
1 parent 87a9100 commit f75bc83

File tree

3 files changed

+32
-23
lines changed

3 files changed

+32
-23
lines changed

src/FeatureExtractor/Mobilenet.js

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,18 +247,21 @@ class Mobilenet {
247247
}
248248
await tf.nextFrame();
249249
this.isPredicting = true;
250-
const predictedClass = tf.tidy(() => {
250+
const predictedClasses = tf.tidy(() => {
251251
const imageResize = (imgToPredict === this.video) ? null : [IMAGE_SIZE, IMAGE_SIZE];
252252
const processedImg = imgToTensor(imgToPredict, imageResize);
253253
const activation = this.mobilenetFeatures.predict(processedImg);
254254
const predictions = this.customModel.predict(activation);
255-
return predictions.as1D().argMax();
255+
return Array.from(predictions.as1D().dataSync());
256256
});
257-
let classId = (await predictedClass.data())[0];
258-
if (this.mapStringToIndex.length > 0) {
259-
classId = this.mapStringToIndex[classId];
260-
}
261-
return classId;
257+
const results = await predictedClasses.map((confidence, index) => {
258+
const label = (this.mapStringToIndex.length > 0 && this.mapStringToIndex[index]) ? this.mapStringToIndex[index] : index;
259+
return {
260+
label,
261+
confidence,
262+
};
263+
}).sort((a, b) => b.confidence - a.confidence);
264+
return results;
262265
}
263266

264267
/* eslint max-len: ["error", { "code": 180 }] */
@@ -295,7 +298,7 @@ class Mobilenet {
295298
});
296299
const prediction = await predictedClass.data();
297300
predictedClass.dispose();
298-
return prediction[0];
301+
return { value: prediction[0] };
299302
}
300303

301304
async load(filesOrPath = null, callback) {

src/ImageClassifier/index.js

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class ImageClassifier {
5252
return this;
5353
}
5454

55-
async predictInternal(imgToPredict, numberOfClasses) {
55+
async classifyInternal(imgToPredict, numberOfClasses) {
5656
// Wait for the model to be ready
5757
await this.ready;
5858
await tf.nextFrame();
@@ -62,10 +62,12 @@ class ImageClassifier {
6262
this.video.onloadeddata = () => resolve();
6363
});
6464
}
65-
return this.model.classify(imgToPredict, numberOfClasses);
65+
return this.model
66+
.classify(imgToPredict, numberOfClasses)
67+
.then(classes => classes.map(c => ({ label: c.className, confidence: c.probability })));
6668
}
6769

68-
async predict(inputNumOrCallback, numOrCallback = null, cb) {
70+
async classify(inputNumOrCallback, numOrCallback = null, cb) {
6971
let imgToPredict = this.video;
7072
let numberOfClasses = this.topk;
7173
let callback;
@@ -102,7 +104,11 @@ class ImageClassifier {
102104
callback = cb;
103105
}
104106

105-
return callCallback(this.predictInternal(imgToPredict, numberOfClasses), callback);
107+
return callCallback(this.classifyInternal(imgToPredict, numberOfClasses), callback);
108+
}
109+
110+
async predict(inputNumOrCallback, numOrCallback, cb) {
111+
return this.classify(inputNumOrCallback, numOrCallback || null, cb);
106112
}
107113
}
108114

src/ImageClassifier/index_test.js

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,35 +48,35 @@ describe('imageClassifier', () => {
4848
expect(classifier.ready).toBeTruthy();
4949
});
5050

51-
describe('predict', () => {
51+
describe('classify', () => {
5252
it('Should classify an image of a Robin', async () => {
5353
const img = await getImage();
54-
await classifier.predict(img)
55-
.then(results => expect(results[0].className).toBe('robin, American robin, Turdus migratorius'));
54+
await classifier.classify(img)
55+
.then(results => expect(results[0].label).toBe('robin, American robin, Turdus migratorius'));
5656
});
5757

5858
it('Should support p5 elements with an image on .elt', async () => {
5959
const img = await getImage();
60-
await classifier.predict({ elt: img })
61-
.then(results => expect(results[0].className).toBe('robin, American robin, Turdus migratorius'));
60+
await classifier.classify({ elt: img })
61+
.then(results => expect(results[0].label).toBe('robin, American robin, Turdus migratorius'));
6262
});
6363

6464
it('Should support HTMLCanvasElement', async () => {
6565
const canvas = await getCanvas();
66-
await classifier.predict(canvas)
67-
.then(results => expect(results[0].className).toBe('robin, American robin, Turdus migratorius'));
66+
await classifier.classify(canvas)
67+
.then(results => expect(results[0].label).toBe('robin, American robin, Turdus migratorius'));
6868
});
6969

7070
it('Should support p5 elements with canvas on .canvas', async () => {
7171
const canvas = await getCanvas();
72-
await classifier.predict({ canvas })
73-
.then(results => expect(results[0].className).toBe('robin, American robin, Turdus migratorius'));
72+
await classifier.classify({ canvas })
73+
.then(results => expect(results[0].label).toBe('robin, American robin, Turdus migratorius'));
7474
});
7575

7676
it('Should support p5 elements with canvas on .elt', async () => {
7777
const canvas = await getCanvas();
78-
await classifier.predict({ elt: canvas })
79-
.then(results => expect(results[0].className).toBe('robin, American robin, Turdus migratorius'));
78+
await classifier.classify({ elt: canvas })
79+
.then(results => expect(results[0].label).toBe('robin, American robin, Turdus migratorius'));
8080
});
8181
});
8282
});

0 commit comments

Comments
 (0)