Skip to content

Commit 1b767a4

Browse files
committed
Updated the preprocessing/postprocessing functions
1 parent f7695a5 commit 1b767a4

File tree

2 files changed

+83
-71
lines changed

2 files changed

+83
-71
lines changed

src/YOLO/index.js

Lines changed: 78 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,17 @@ const DEFAULTS = {
1616
IOUThreshold: 0.4,
1717
classProbThreshold: 0.4,
1818
URL: 'https://raw.githubusercontent.com/ml5js/ml5-library/master/src/YOLO/model.json',
19-
19+
imageSize: 416,
2020
};
2121

22-
class YOLO {
22+
class YOLOBase {
2323
constructor(options) {
2424
this.filterBoxesThreshold = options.filterBoxesThreshold || DEFAULTS.filterBoxesThreshold;
2525
this.IOUThreshold = options.IOUThreshold || DEFAULTS.IOUThreshold;
2626
this.classProbThreshold = options.classProbThreshold || DEFAULTS.classProbThreshold;
2727
this.modelURL = options.url || DEFAULTS.URL;
2828
this.model = null;
29-
this.inputWidth = 416;
30-
this.inputHeight = 416;
29+
this.imageSize = options.imageSize || DEFAULTS.imageSize;
3130
this.classNames = CLASS_NAMES;
3231
this.anchors = [
3332
[0.57273, 0.677385],
@@ -36,8 +35,6 @@ class YOLO {
3635
[7.88282, 3.52778],
3736
[9.77052, 9.16828],
3837
];
39-
// this.scaleX;
40-
// this.scaleY;
4138
this.anchorsLength = this.anchors.length;
4239
this.classesLength = this.classNames.length;
4340
this.init();
@@ -71,11 +68,16 @@ class YOLO {
7168
});
7269
}
7370

74-
// takes HTMLCanvasElement || HTMLImageElement ||HTMLVideoElement || ImageData as input
75-
// outs results obj
76-
async detect(input) {
71+
/**
72+
* Infers through the model.
73+
* TODO : Optionally takes an endpoint to return an intermediate activation.
74+
* @param img The image to classify. Can be a tensor or a DOM element image,
75+
* video, or canvas.
76+
* img: tf.Tensor3D|ImageData|HTMLImageElement|HTMLCanvasElement|HTMLVideoElement
77+
*/
78+
async detect(img) {
7779
const predictions = tf.tidy(() => {
78-
const data = this.preProccess(input);
80+
const data = this.preProccess(img);
7981
const preds = this.model.predict(data);
8082
return preds;
8183
});
@@ -99,63 +101,73 @@ class YOLO {
99101
tf.disposeconstiables();
100102
}
101103

102-
// should be called after loadModel()
103-
cache() {
104-
tf.tidy(() => {
105-
const dummy = tf.zeros([0, 416, 416, 3]);
106-
this.model.predict(dummy);
107-
});
104+
// should be called after load()
105+
async cache() {
106+
const dummy = tf.zeros([416, 416, 3]);
107+
await this.detect(dummy);
108+
dummy.dispose();
108109
}
109110

110-
preProccess(input) {
111-
const img = tf.fromPixels(input);
112-
const [w, h] = [img.shape[1], img.shape[0]];
113-
this.imgWidth = w;
114-
this.imgHeight = h;
111+
preProccess(img) {
112+
let image;
113+
if (!(img instanceof tf.Tensor)) {
114+
if (img instanceof HTMLImageElement || img instanceof HTMLVideoElement) {
115+
image = tf.fromPixels(img);
116+
} else if (typeof img === 'object' && (img.elt instanceof HTMLImageElement || img.elt instanceof HTMLVideoElement)) {
117+
image = tf.fromPixels(img.elt); // Handle p5.js image and video.
118+
}
119+
} else {
120+
image = img;
121+
}
115122

116-
const img1 = tf.image.resizeBilinear(img, [this.inputHeight, this.inputWidth]).toFloat().div(tf.scalar(255)).expandDims(0);
123+
[this.imgWidth, this.imgHeight] = [image.shape[1], image.shape[0]];
117124

125+
// Normalize the image from [0, 255] to [0, 1].
126+
const normalized = image.toFloat().div(tf.scalar(255));
127+
let resized = normalized;
128+
if (normalized.shape[0] !== this.imageSize || normalized.shape[1] !== this.imageSize) {
129+
const alignCorners = true;
130+
resized = tf.image.resizeBilinear(normalized, [this.imageSize, this.imageSize], alignCorners);
131+
}
132+
// Reshape to a single-element batch so we can pass it to predict.
133+
const batched = resized.reshape([1, this.imageSize, this.imageSize, 3]);
118134
// Scale Stuff
119-
this.scaleX = this.imgHeight / this.inputHeight;
120-
this.scaleY = this.imgWidth / this.inputWidth;
121-
return img1;
135+
// this.scaleX = this.imgHeight / this.inputHeight;
136+
// this.scaleY = this.imgWidth / this.inputWidth;
137+
return batched;
122138
}
123139

140+
141+
/**
142+
* Post-processing for the YOLO output.
143+
* TODO : make this more modular in preparation for yolov3-tiny
144+
* @param rawPrediction a 4D tensor 13*13*425
145+
*/
124146
async postProccess(rawPrediction) {
125-
const results = {
126-
totalDetections: 0,
127-
detections: [],
128-
};
129147
const [boxes, boxScores, classes, Indices] = tf.tidy(() => {
130-
const rawPrediction1 = tf.reshape(rawPrediction, [13, 13, this.anchorsLength, this.classesLength + 5]);
148+
const reshaped = tf.reshape(rawPrediction, [13, 13, this.anchorsLength, this.classesLength + 5]);
131149
// Box Coords
132-
const boxxy = tf.sigmoid(rawPrediction1.slice([0, 0, 0, 0], [13, 13, this.anchorsLength, 2]));
133-
const boxwh = tf.exp(rawPrediction1.slice([0, 0, 0, 2], [13, 13, this.anchorsLength, 2]));
150+
const boxxy = tf.sigmoid(reshaped.slice([0, 0, 0, 0], [13, 13, this.anchorsLength, 2]));
151+
const boxwh = tf.exp(reshaped.slice([0, 0, 0, 2], [13, 13, this.anchorsLength, 2]));
134152
// ObjectnessScore
135-
const boxConfidence = tf.sigmoid(rawPrediction1.slice([0, 0, 0, 4], [13, 13, this.anchorsLength, 1]));
153+
const boxConfidence = tf.sigmoid(reshaped.slice([0, 0, 0, 4], [13, 13, this.anchorsLength, 1]));
136154
// ClassProb
137-
const boxClassProbs = tf.softmax(rawPrediction1.slice([0, 0, 0, 5], [13, 13, this.anchorsLength, this.classesLength]));
155+
const boxClassProbs = tf.softmax(reshaped.slice([0, 0, 0, 5], [13, 13, this.anchorsLength, this.classesLength]));
138156

139157
// from boxes with xy wh to x1,y1 x2,y2
158+
// xy:bounding box center wh:width/Height
140159
// Mainly for NMS + rescaling
141-
/*
142-
x1 = x + (h/2)
143-
y1 = y - (w/2)
144-
x2 = x - (h/2)
145-
y2 = y + (w/2)
146-
*/
147-
// BoxScale
148-
const boxXY1 = tf.div(tf.add(boxxy, this.ConvIndex), this.ConvDims);
149-
150-
const boxWH1 = tf.div(tf.mul(boxwh, this.AnchorsTensor), this.ConvDims);
151-
152-
const Div = tf.div(boxWH1, tf.scalar(2));
153-
154-
const boxMins = tf.sub(boxXY1, Div);
155-
const boxMaxes = tf.add(boxXY1, Div);
156-
160+
// x1 = x - (w/2)
161+
// y1 = y - (h/2)
162+
// x2 = x + (w/2)
163+
// y2 = y + (h/2)
164+
165+
const boxxy1 = tf.div(tf.add(boxxy, this.ConvIndex), this.ConvDims);
166+
const boxwh1 = tf.div(tf.mul(boxwh, this.AnchorsTensor), this.ConvDims);
167+
const div = tf.div(boxwh1, tf.scalar(2));
168+
const boxMins = tf.sub(boxxy1, div);
169+
const boxMaxes = tf.add(boxxy1, div);
157170
const size = [boxMins.shape[0], boxMins.shape[1], boxMins.shape[2], 1];
158-
159171
// main box tensor
160172
const finalboxes = tf.concat([
161173
boxMins.slice([0, 0, 0, 1], size),
@@ -166,44 +178,48 @@ class YOLO {
166178

167179
// Filterboxes by objectness threshold
168180
// not filtering / getting a mask really
169-
170181
const boxConfidence1 = boxConfidence.squeeze([3]);
171182
const objectnessMask = tf.greaterEqual(boxConfidence1, tf.scalar(this.filterBoxesThreshold));
172183

173184
// Filterboxes by class probability threshold
174185
const boxScores1 = tf.mul(boxConfidence1, tf.max(boxClassProbs, 3));
175186
const boxClassProbMask = tf.greaterEqual(boxScores1, tf.scalar(this.classProbThreshold));
176187

177-
// getting classes indices
188+
// getting classes indices
178189
const classes1 = tf.argMax(boxClassProbs, -1);
179190

180-
// Final Mask each elem that survived both filters (0x0 0x1 1x0 = fail ) 1x1 = survived
191+
// Final Mask each elem that survived both filters (0x0 0x1 1x0 = fail ) 1x1 = survived
181192
const finalMask = boxClassProbMask.mul(objectnessMask);
182193

183194
const indices = finalMask.flatten().toInt().mul(this.indicesTensor);
184195
return [finalboxes, boxScores1, classes1, indices];
185196
});
186197

187198
// we started at one in the range so we remove 1 now
199+
// this is where a major bottleneck happens
200+
// this can be replaced with tf.boolean_mask() if tfjs team implements it
201+
// this is also why we have 2 tf.tidy()'s
202+
// more info : https://github.com/ModelDepot/tfjs-yolo-tiny/issues/6
188203

189204
const indicesArr = Array.from(await Indices.data()).filter(i => i > 0).map(i => i - 1);
190205

191206
if (indicesArr.length === 0) {
192207
boxes.dispose();
193208
boxScores.dispose();
194209
classes.dispose();
195-
return results;
210+
return [];
196211
}
212+
197213
const [filteredBoxes, filteredScores, filteredclasses] = tf.tidy(() => {
198214
const indicesTensor = tf.tensor1d(indicesArr, 'int32');
199215
const filteredBoxes1 = boxes.gather(indicesTensor);
200216
const filteredScores1 = boxScores.flatten().gather(indicesTensor);
201217
const filteredclasses1 = classes.flatten().gather(indicesTensor);
202-
// Img Rescale
218+
// Image Rescale
203219
const Height = tf.scalar(this.imgHeight);
204220
const Width = tf.scalar(this.imgWidth);
205-
// 4
206221
const ImageDims = tf.stack([Height, Width, Height, Width]).reshape([1, 4]);
222+
207223
const filteredBoxes2 = filteredBoxes1.mul(ImageDims);
208224
return [filteredBoxes2, filteredScores1, filteredclasses1];
209225
});
@@ -240,17 +256,16 @@ class YOLO {
240256
});
241257

242258
// final phase
243-
259+
const detections = [];
244260
// add any output you want
245261
for (let id = 0; id < selectedBoxes.length; id += 1) {
246262
const classProb = selectedBoxes[id][0];
247263
const classProbRounded = Math.round(classProb * 1000) / 10;
248264
const className = this.classNames[selectedBoxes[id][2]];
249265
const classIndex = selectedBoxes[id][2];
250266
const [y1, x1, y2, x2] = selectedBoxes[id][1];
251-
// Need to get this out
252267
// TODO : add a hsla color for later visualization
253-
const resultObj = {
268+
const detection = {
254269
id,
255270
className,
256271
classIndex,
@@ -261,14 +276,11 @@ class YOLO {
261276
x2,
262277
y2,
263278
};
264-
results.detections.push(resultObj);
279+
detections.push(detection);
265280
}
266-
// Misc
267-
results.totalDetections = results.detections.length;
268-
results.scaleX = this.scaleX;
269-
results.scaleY = this.scaleY;
270-
return results;
281+
return detections;
271282
}
272283
}
273284

285+
/**
 * Factory for the YOLO detector.
 *
 * Declared as a regular function (not an arrow function) so that both
 * `YOLO(options)` and `new YOLO(options)` work: arrow functions cannot be
 * used with `new`, whereas a regular function that returns an object hands
 * that object back to the `new` expression.
 *
 * @param {Object} [options={}] Detector options forwarded to YOLOBase.
 *   Defaults to an empty object so that calling without arguments falls back
 *   to the defaults instead of throwing on property access.
 * @returns {YOLOBase} A new YOLOBase instance.
 */
function YOLO(options = {}) {
  return new YOLOBase(options);
}
274286
export default YOLO;

src/YOLO/index_test.js

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
// This software is released under the MIT License.
44
// https://opensource.org/licenses/MIT
55

6-
const { YOLO } = ml5;
6+
const { tf, YOLO } = ml5;
77

88
const YOLO_DEFAULTS = {
9+
filterBoxesThreshold: 0.01,
910
IOUThreshold: 0.4,
1011
classProbThreshold: 0.4,
11-
filterBoxesThreshold: 0.01,
12-
size: 416,
12+
imageSize: 416,
1313
};
1414

1515
describe('YOLO', () => {
@@ -25,14 +25,14 @@ describe('YOLO', () => {
2525

2626
beforeEach(async () => {
2727
jasmine.DEFAULT_TIMEOUT_INTERVAL = 100000;
28-
yolo = await YOLO();
28+
yolo = new YOLO();
2929
});
3030

3131
it('instantiates the YOLO classifier with defaults', () => {
3232
expect(yolo.IOUThreshold).toBe(YOLO_DEFAULTS.IOUThreshold);
3333
expect(yolo.classProbThreshold).toBe(YOLO_DEFAULTS.classProbThreshold);
3434
expect(yolo.filterBoxesThreshold).toBe(YOLO_DEFAULTS.filterBoxesThreshold);
35-
expect(yolo.size).toBe(YOLO_DEFAULTS.size);
35+
expect(yolo.imageSize).toBe(YOLO_DEFAULTS.imageSize);
3636
});
3737

3838
it('detects a robin', async () => {

0 commit comments

Comments
 (0)