Skip to content

Commit e559690

Browse files
hiddentncvalenzuela
authored andcommitted
added darknet reference and darknet tiny classifiers (#201)
* added darknet reference and darknet tiny classifiers * fixed linting stuff * edits.. * edits2.0 * clean comments
1 parent d441c34 commit e559690

File tree

3 files changed

+1144
-7
lines changed

3 files changed

+1144
-7
lines changed

src/ImageClassifier/darknet.js

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright (c) 2018 ml5
2+
//
3+
// This software is released under the MIT License.
4+
// https://opensource.org/licenses/MIT
5+
6+
import * as tf from '@tensorflow/tfjs';
7+
import IMAGENET_CLASSES_DARKNET from '../utils/IMAGENET_CLASSES_DARKNET';
8+
9+
const DEFAULTS = {
10+
DARKNET_URL: 'https://rawgit.com/ml5js/ml5-data-and-models/master/models/darknetclassifier/darknetreference/model.json',
11+
DARKNET_TINY_URL: 'https://rawgit.com/ml5js/ml5-data-and-models/master/models/darknetclassifier/darknettiny/model.json',
12+
IMAGE_SIZE_DARKNET: 256,
13+
IMAGE_SIZE_DARKNET_TINY: 224,
14+
};
15+
16+
async function getTopKClasses(logits, topK) {
17+
const values = await logits.data();
18+
const valuesAndIndices = [];
19+
for (let i = 0; i < values.length; i += 1) {
20+
valuesAndIndices.push({
21+
value: values[i],
22+
index: i,
23+
});
24+
}
25+
valuesAndIndices.sort((a, b) => b.value - a.value);
26+
27+
const topkValues = new Float32Array(topK);
28+
const topkIndices = new Int32Array(topK);
29+
for (let i = 0; i < topK; i += 1) {
30+
topkValues[i] = valuesAndIndices[i].value;
31+
topkIndices[i] = valuesAndIndices[i].index;
32+
}
33+
34+
const topClassesAndProbs = [];
35+
for (let i = 0; i < topkIndices.length; i += 1) {
36+
topClassesAndProbs.push({
37+
className: IMAGENET_CLASSES_DARKNET[topkIndices[i]],
38+
probability: topkValues[i],
39+
});
40+
}
41+
return topClassesAndProbs;
42+
}
43+
44+
function preProcess(img, size) {
45+
let image;
46+
if (!(img instanceof tf.Tensor)) {
47+
if (img instanceof HTMLImageElement || img instanceof HTMLVideoElement) {
48+
image = tf.fromPixels(img);
49+
} else if (typeof img === 'object' && (img.elt instanceof HTMLImageElement || img.elt instanceof HTMLVideoElement)) {
50+
image = tf.fromPixels(img.elt); // Handle p5.js image and video.
51+
}
52+
} else {
53+
image = img;
54+
}
55+
const normalized = image.toFloat().div(tf.scalar(255));
56+
let resized = normalized;
57+
if (normalized.shape[0] !== size || normalized.shape[1] !== size) {
58+
const alignCorners = true;
59+
resized = tf.image.resizeBilinear(normalized, [size, size], alignCorners);
60+
}
61+
const batched = resized.reshape([1, size, size, 3]);
62+
return batched;
63+
}
64+
65+
export class Darknet {
66+
constructor(version) {
67+
this.version = version;
68+
switch (this.version) {
69+
case 'reference':
70+
this.imgSize = DEFAULTS.IMAGE_SIZE_DARKNET;
71+
break;
72+
case 'tiny':
73+
this.imgSize = DEFAULTS.IMAGE_SIZE_DARKNET_TINY;
74+
break;
75+
default:
76+
break;
77+
}
78+
}
79+
80+
async load() {
81+
switch (this.version) {
82+
case 'reference':
83+
this.model = await tf.loadModel(DEFAULTS.DARKNET_URL);
84+
break;
85+
case 'tiny':
86+
this.model = await tf.loadModel(DEFAULTS.DARKNET_TINY_URL);
87+
break;
88+
default:
89+
break;
90+
}
91+
92+
// Warmup the model.
93+
const result = tf.tidy(() => this.model.predict(tf.zeros([1, this.imgSize, this.imgSize, 3])));
94+
await result.data();
95+
result.dispose();
96+
}
97+
98+
async classify(img, topk = 3) {
99+
const logits = tf.tidy(() => {
100+
const imgData = preProcess(img, this.imgSize);
101+
const predictions = this.model.predict(imgData);
102+
return tf.softmax(predictions);
103+
});
104+
const classes = await getTopKClasses(logits, topk);
105+
logits.dispose();
106+
return classes;
107+
}
108+
}
109+
110+
export async function load(version) {
111+
if (version !== 'reference' && version !== 'tiny') {
112+
throw new Error('Please select a version: darknet-reference or darknet-tiny');
113+
}
114+
115+
const darknet = new Darknet(version);
116+
await darknet.load();
117+
return darknet;
118+
}

src/ImageClassifier/index.js

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Image Classifier using pre-trained networks
99

1010
import * as tf from '@tensorflow/tfjs';
1111
import * as mobilenet from '@tensorflow-models/mobilenet';
12+
import * as darknet from './darknet';
1213
import callCallback from '../utils/callcallback';
1314

1415
const DEFAULTS = {
@@ -23,14 +24,24 @@ class ImageClassifier {
2324
constructor(modelName, video, options, callback) {
2425
this.modelName = modelName;
2526
this.video = video;
26-
this.version = options.version || DEFAULTS[this.modelName].version;
27-
this.alpha = options.alpha || DEFAULTS[this.modelName].alpha;
28-
this.topk = options.topk || DEFAULTS[this.modelName].topk;
2927
this.model = null;
30-
if (this.modelName === 'mobilenet') {
31-
this.modelToUse = mobilenet;
32-
} else {
33-
this.modelToUse = null;
28+
switch (this.modelName) {
29+
case 'mobilenet':
30+
this.modelToUse = mobilenet;
31+
this.version = options.version || DEFAULTS.mobilenet.version;
32+
this.alpha = options.alpha || DEFAULTS.mobilenet.alpha;
33+
this.topk = options.topk || DEFAULTS.mobilenet.topk;
34+
break;
35+
case 'darknet':
36+
this.version = 'reference'; // this a 28mb model
37+
this.modelToUse = darknet;
38+
break;
39+
case 'darknet-tiny':
40+
this.version = 'tiny'; // this a 4mb model
41+
this.modelToUse = darknet;
42+
break;
43+
default:
44+
this.modelToUse = null;
3445
}
3546
// Load the model
3647
this.ready = callCallback(this.loadModel(), callback);

0 commit comments

Comments
 (0)