Skip to content

Commit bb23209

Browse files
committed
fixes some stuff
1 parent 531e21f commit bb23209

File tree

2 files changed

+42
-36
lines changed

2 files changed

+42
-36
lines changed

src/YOLO/index.js

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@ const DEFAULTS = {
2020
};
2121

2222
class YOLOBase {
23-
constructor(options) {
23+
constructor(options = DEFAULTS) {
2424
this.filterBoxesThreshold = options.filterBoxesThreshold || DEFAULTS.filterBoxesThreshold;
2525
this.IOUThreshold = options.IOUThreshold || DEFAULTS.IOUThreshold;
2626
this.classProbThreshold = options.classProbThreshold || DEFAULTS.classProbThreshold;
2727
this.modelURL = options.url || DEFAULTS.URL;
28-
this.model = null;
2928
this.imageSize = options.imageSize || DEFAULTS.imageSize;
3029
this.classNames = CLASS_NAMES;
3130
this.anchors = [
@@ -35,20 +34,30 @@ class YOLOBase {
3534
[7.88282, 3.52778],
3635
[9.77052, 9.16828],
3736
];
38-
this.anchorsLength = this.anchors.length;
39-
this.classesLength = this.classNames.length;
4037
this.init();
4138
}
4239

4340
init() {
44-
const Aten = tf.tensor2d(this.anchors);
45-
this.anchorsTensor = tf.reshape(Aten, [1, 1, this.anchorsLength, 2]);
46-
Aten.dispose();
41+
const outputWidth = 13;
42+
const outputHeight = 13;
43+
44+
[this.convIndex, this.convDims, this.anchorsTensor] = tf.tidy(() => {
45+
const Atensor = tf.tensor2d(this.anchors);
46+
const anchorsTensor = tf.reshape(Atensor, [1, 1, Atensor.shape[0], 2]);
47+
48+
let convIndex = tf.range(0, outputWidth);
49+
const convHeightIndex = tf.tile(convIndex, [outputHeight]);
50+
let convWidthindex = tf.tile(tf.expandDims(convIndex, 0), [outputWidth, 1]);
51+
convWidthindex = tf.transpose(convWidthindex).flatten();
52+
convIndex = tf.transpose(tf.stack([convHeightIndex, convWidthindex]));
53+
convIndex = tf.reshape(convIndex, [outputWidth, outputHeight, 1, 2]);
54+
const convDims = tf.reshape(tf.tensor1d([outputWidth, outputHeight]), [1, 1, 1, 2]);
55+
return [convIndex, convDims, anchorsTensor];
56+
});
4757
}
4858

4959
/**
5060
* Infers through the model.
51-
* TODO : Optionally takes an endpoint to return an intermediate activation.
5261
* @param img The image to classify. Can be a tensor or a DOM element image,
5362
* video, or canvas.
5463
* img: tf.Tensor3D|ImageData|HTMLImageElement|HTMLCanvasElement|HTMLVideoElement
@@ -60,6 +69,7 @@ class YOLOBase {
6069
return preds;
6170
});
6271
const results = await this.postProcess(predictions);
72+
predictions.dispose();
6373
return results;
6474
}
6575

@@ -122,8 +132,7 @@ class YOLOBase {
122132
* @param rawPrediction a 4D tensor
123133
*/
124134
async postProcess(rawPrediction) {
125-
const [boxes, boxScores, classes, Indices] = tf.tidy(() => this.split(rawPrediction.squeeze([0]), this.anchorsTensor));
126-
135+
const [boxes, boxScores, classes, Indices] = tf.tidy(() => this.split(rawPrediction.squeeze([0])));
127136
// indices were generated starting at 1 so that 0 can mark filtered-out boxes; drop the zeros and subtract 1 to recover zero-based indices
128137
const indicesArr = Array.from(await Indices.data()).filter(i => i > 0).map(i => i - 1);
129138

@@ -148,6 +157,7 @@ class YOLOBase {
148157
// image-size scale factors, one per box component (x, y, w, h)
149158
const ImageDims = tf.stack([Width, Height, Width, Height]).reshape([1, 4]);
150159
const filteredBoxes2 = filteredBoxes1.mul(ImageDims);
160+
151161
return [filteredBoxes2, filteredScores1, filteredclasses1];
152162
});
153163
boxes.dispose();
@@ -218,46 +228,42 @@ class YOLOBase {
218228
return detections;
219229
}
220230

221-
split(rawPrediction, AnchorsTensor) {
231+
split(rawPrediction) {
222232
const [outputWidth, outputHeight] = [rawPrediction.shape[0], rawPrediction.shape[1]];
223-
const reshaped = tf.reshape(rawPrediction, [outputWidth, outputHeight, this.anchorsLength, this.classesLength + 5]);
224-
// Box xywh
225-
const boxxy = tf.sigmoid(reshaped.slice([0, 0, 0, 0], [outputWidth, outputHeight, this.anchorsLength, 2]));
226-
const boxwh = tf.exp(reshaped.slice([0, 0, 0, 2], [outputWidth, outputHeight, this.anchorsLength, 2]));
233+
const anchorsLength = this.anchorsTensor.shape[2];
234+
const classesLength = this.classNames.length;
235+
const reshaped = tf.reshape(rawPrediction, [outputWidth, outputHeight, anchorsLength, classesLength + 5]);
236+
// Box xy_wh
237+
let boxxy = tf.sigmoid(reshaped.slice([0, 0, 0, 0], [outputWidth, outputHeight, anchorsLength, 2]));
238+
let boxwh = tf.exp(reshaped.slice([0, 0, 0, 2], [outputWidth, outputHeight, anchorsLength, 2]));
227239
// objectnessScore
228-
const boxConfidence = tf.sigmoid(reshaped.slice([0, 0, 0, 4], [outputWidth, outputHeight, this.anchorsLength, 1]));
240+
let boxConfidence = tf.sigmoid(reshaped.slice([0, 0, 0, 4], [outputWidth, outputHeight, anchorsLength, 1]));
229241
// classProb
230-
const boxClassProbs = tf.softmax(reshaped.slice([0, 0, 0, 5], [outputWidth, outputHeight, this.anchorsLength, this.classesLength]));
231-
232-
let ConvIndex = tf.range(0, outputWidth);
233-
const ConvHeightIndex = tf.tile(ConvIndex, [outputHeight]);
234-
let ConvWidthindex = tf.tile(tf.expandDims(ConvIndex, 0), [outputWidth, 1]);
235-
ConvWidthindex = tf.transpose(ConvWidthindex).flatten();
236-
ConvIndex = tf.transpose(tf.stack([ConvHeightIndex, ConvWidthindex]));
237-
ConvIndex = tf.reshape(ConvIndex, [outputWidth, outputHeight, 1, 2]);
238-
const ConvDims = tf.reshape(tf.tensor1d([outputWidth, outputHeight]), [1, 1, 1, 2]);
242+
const boxClassProbs = tf.softmax(reshaped.slice([0, 0, 0, 5], [outputWidth, outputHeight, anchorsLength, classesLength]));
239243

240-
const boxxy1 = tf.div(tf.add(boxxy, ConvIndex), ConvDims);
241-
const boxwh1 = tf.div(tf.mul(boxwh, AnchorsTensor), ConvDims);
244+
boxxy = tf.div(tf.add(boxxy, this.convIndex), this.convDims);
245+
boxwh = tf.div(tf.mul(boxwh, this.anchorsTensor), this.convDims);
242246

243-
const finalboxes = tf.concat([boxxy1, boxwh1], 3).reshape([(outputWidth * outputHeight * this.anchorsLength), 4]);
247+
const finalboxes = tf.concat([boxxy, boxwh], 3).reshape([(outputWidth * outputHeight * anchorsLength), 4]);
244248

245249
// filter boxes by objectness threshold
246-
const boxConfidence1 = boxConfidence.squeeze([3]);
247-
const objectnessMask = tf.greaterEqual(boxConfidence1, tf.scalar(this.filterBoxesThreshold));
250+
boxConfidence = boxConfidence.squeeze([3]);
251+
const objectnessMask = tf.greaterEqual(boxConfidence, tf.scalar(this.filterBoxesThreshold));
248252

249253
// filter boxes by class probability threshold
250-
const boxScores1 = tf.mul(boxConfidence1, tf.max(boxClassProbs, 3));
251-
const boxClassProbMask = tf.greaterEqual(boxScores1, tf.scalar(this.classProbThreshold));
254+
const boxScores = tf.mul(boxConfidence, tf.max(boxClassProbs, 3));
255+
const boxClassProbMask = tf.greaterEqual(boxScores, tf.scalar(this.classProbThreshold));
252256

253257
// classes indices
254-
const classes1 = tf.argMax(boxClassProbs, -1);
258+
const classes = tf.argMax(boxClassProbs, -1);
255259

256-
const indicesTensor = tf.range(1, (outputWidth * outputHeight * this.anchorsLength) + 1, 1, 'int32');
260+
const indicesTensor = tf.range(1, (outputWidth * outputHeight * anchorsLength) + 1, 1, 'int32');
257261
// Final mask: marks each element that survived both filters
258262
const finalMask = boxClassProbMask.mul(objectnessMask);
263+
259264
const indices = finalMask.flatten().toInt().mul(indicesTensor);
260-
return [finalboxes, boxScores1, classes1, indices];
265+
266+
return [finalboxes, boxScores, classes, indices];
261267
}
262268
}
263269

src/YOLO/index_test.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ describe('YOLO', () => {
2626
beforeEach(async () => {
2727
jasmine.DEFAULT_TIMEOUT_INTERVAL = 100000;
2828
yolo = YOLO();
29-
await yolo.loadModel();
3029
});
3130

3231
it('instantiates the YOLO classifier with defaults', () => {
@@ -38,6 +37,7 @@ describe('YOLO', () => {
3837

3938
it('detects a robin', async () => {
4039
const robin = await getRobin();
40+
await yolo.loadModel();
4141
const detection = await yolo.detect(robin);
4242
expect(detection[0].className).toBe('bird');
4343
});

0 commit comments

Comments
 (0)