
Commit 3a158a2

committed
redid the YOLO object detection algorithm
1 parent 82843f8 commit 3a158a2

File tree: 2 files changed (+230, -271 lines)


src/YOLO/index.js

Lines changed: 230 additions & 118 deletions
@@ -9,160 +9,272 @@ YOLO Object detection
Heavily derived from https://github.com/ModelDepot/tfjs-yolo-tiny (ModelDepot: modeldepot.io)
*/

-import * as tf from '@tensorflow/tfjs';
-import Video from '../utils/Video';
-import { imgToTensor } from '../utils/imageUtilities';
-
-import CLASS_NAMES from './../utils/COCO_CLASSES';
-
-import {
-  nonMaxSuppression,
-  boxesToCorners,
-  head,
-  filterBoxes,
-  ANCHORS,
-} from './postprocess';
-
-const URL = 'https://raw.githubusercontent.com/ml5js/ml5-library/master/src/YOLO/model.json';
+import * as tf from "@tensorflow/tfjs";
+import CLASS_NAMES from "./../utils/COCO_CLASSES";

const DEFAULTS = {
  filterBoxesThreshold: 0.01,
  IOUThreshold: 0.4,
  classProbThreshold: 0.4,
+  URL: "https://raw.githubusercontent.com/ml5js/ml5-library/master/src/YOLO/model.json",
};

-// Size of the video
-const imageSize = 416;
-
-class YOLOBase extends Video {
-  constructor(video, options, callback) {
-    super(video, imageSize);
+class YOLO {

+  constructor(options) {
    this.filterBoxesThreshold = options.filterBoxesThreshold || DEFAULTS.filterBoxesThreshold;
    this.IOUThreshold = options.IOUThreshold || DEFAULTS.IOUThreshold;
    this.classProbThreshold = options.classProbThreshold || DEFAULTS.classProbThreshold;
-    this.modelReady = false;
-    this.isPredicting = false;
-    this.loadModel(callback);
-  }
+    this.modelURL = options.url || DEFAULTS.URL;
+    this.model = null;
+    this.inputWidth = 416;
+    this.inputHeight = 416;
+    this.classNames = CLASS_NAMES;
+    this.anchors = [
+      [0.57273, 0.677385],
+      [1.87446, 2.06253],
+      [3.33843, 5.47434],
+      [7.88282, 3.52778],
+      [9.77052, 9.16828]
+    ];
+    this.scaleX;
+    this.scaleY;
+    this.anchorsLength = this.anchors.length;
+    this.classesLength = this.classNames.length;
+    this.init();

-  async loadModel(callback) {
-    return this.loadVideo().then(async () => {
-      this.model = await tf.loadModel(URL);
-      this.modelReady = true;
-      callback();
-    });
  }

-  async detect(inputOrCallback, cb = null) {
-    if (this.modelReady && this.video && !this.predicting) {
-      let imgToPredict;
-      let callback = cb;
-      this.isPredicting = true;
-
-      if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement) {
-        imgToPredict = inputOrCallback;
-      } else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement)) {
-        imgToPredict = inputOrCallback.elt; // Handle p5.js image and video.
-      } else if (typeof inputOrCallback === 'function') {
-        imgToPredict = this.video;
-        callback = inputOrCallback;
-      }
+  init() {
+    // indices tensor to filter the elements later on
+    this.indicesTensor = tf.range(1, 846, 1, "int32");

-      const input = imgToTensor(imgToPredict);
+    // Grid To Split the raw predictions : Assumes Our Model output is 1 Tensor with 13x13x425
+    // gonna hard code all this stuff see if it works
+    // this can be done once at the initial phase
+    // TODO : make this more modular

-      const [allBoxes, boxConfidence, boxClassProbs] = tf.tidy(() => {
-        const activation = this.model.predict(input);
-        const [boxXY, boxWH, bConfidence, bClassProbs] = head(activation, ANCHORS, 80);
-        const aBoxes = boxesToCorners(boxXY, boxWH);
-        return [aBoxes, bConfidence, bClassProbs];
-      });
+    [this.ConvIndex, this.ConvDims, this.AnchorsTensor] = tf.tidy(() => {
+      let ConvIndex = tf.range(0, 13);
+      let ConvHeightIndex = tf.tile(ConvIndex, [13]);

-      const [boxes, scores, classes] = await filterBoxes(allBoxes, boxConfidence, boxClassProbs, this.filterBoxesThreshold);
+      let ConvWidthindex = tf.tile(tf.expandDims(ConvIndex, 0), [13, 1]);
+      ConvWidthindex = tf.transpose(ConvWidthindex).flatten();

-      // If all boxes have been filtered out
-      if (boxes == null) {
-        return [];
-      }
+      ConvIndex = tf.transpose(tf.stack([ConvHeightIndex, ConvWidthindex]));
+      ConvIndex = tf.reshape(ConvIndex, [13, 13, 1, 2]);

-      const width = tf.scalar(imageSize);
-      const height = tf.scalar(imageSize);
-      const imageDims = tf.stack([height, width, height, width]).reshape([1, 4]);
-      const boxesModified = tf.mul(boxes, imageDims);
+      let ConvDims = tf.reshape(tf.tensor1d([13, 13]), [1, 1, 1, 2]);
+      // AnchorsTensor
+      let Aten = tf.tensor2d(this.anchors);
+      let AnchorsTensor = tf.reshape(Aten, [1, 1, this.anchorsLength, 2]);

-      const [preKeepBoxesArr, scoresArr] = await Promise.all([
-        boxesModified.data(), scores.data(),
-      ]);
+      return [ConvIndex, ConvDims, AnchorsTensor];
+    });
+  }

-      const [keepIndx, boxesArr, keepScores] = nonMaxSuppression(
-        preKeepBoxesArr,
-        scoresArr,
-        this.IOUThreshold,
-      );
+  // takes HTMLCanvasElement || HTMLImageElement || HTMLVideoElement || ImageData as input
+  // outs results obj
+  async detect(input) {
+    const predictions = tf.tidy(() => {
+      const data = this.preProccess(input);
+      const preds = this.model.predict(data);
+      return preds;
+    })
+    const results = await this.postProccess(predictions);
+    return results
+  }

-      const classesIndxArr = await classes.gather(tf.tensor1d(keepIndx, 'int32')).data();
+  async loadModel() {
+    try {
+      this.model = await tf.loadModel(this.modelURL);
+      return true;
+    } catch (e) {
+      console.log(e);
+      return false;
+    }
+  }

-      const results = [];

-      classesIndxArr.forEach((classIndx, i) => {
-        const classProb = keepScores[i];
-        if (classProb < this.classProbThreshold) {
-          return;
-        }
+  // does not dispose of the model atm
+  dispose() {
+    tf.disposeVariables();
+  }
+
+  // should be called after loadModel()
+  cache() {
+    tf.tidy(() => {
+      const dummy = tf.zeros([0, 416, 416, 3])
+      const data = this.model.predict(dummy)
+    })
+  }

-        const className = CLASS_NAMES[classIndx];
-        let [y, x, h, w] = boxesArr[i];

-        y = Math.max(0, y);
-        x = Math.max(0, x);
-        h = Math.min(imageSize, h) - y;
-        w = Math.min(imageSize, w) - x;

-        const resultObj = {
-          className,
-          classProb,
-          x: x / imageSize,
-          y: y / imageSize,
-          w: w / imageSize,
-          h: h / imageSize,
-        };
+  preProccess(input) {
+    let img = tf.fromPixels(input)
+    this.imgWidth = img.shape[1];
+    this.imgHeight = img.shape[0];
+    img = tf.image.resizeBilinear(img, [this.inputHeight, this.inputWidth])
+      .toFloat()
+      .div(tf.scalar(255))
+      .expandDims(0);
+    // Scale Stuff
+    this.scaleX = this.imgHeight / this.inputHeight;
+    this.scaleY = this.imgWidth / this.inputWidth;
+    return img
+  }

-        results.push(resultObj);
-      });

-      await tf.nextFrame();
-      this.isPredicting = false;
+  async postProccess(rawPrediction) {
+
+    let results = { totalDetections: 0, detections: [] }
+
+    const [boxes, BoxScores, Classes, Indices] = tf.tidy(() => {
+
+      rawPrediction = tf.reshape(rawPrediction, [13, 13, this.anchorsLength, this.classesLength + 5]);
+      // Box Coords
+      let BoxXY = tf.sigmoid(rawPrediction.slice([0, 0, 0, 0], [13, 13, this.anchorsLength, 2]))
+      let BoxWH = tf.exp(rawPrediction.slice([0, 0, 0, 2], [13, 13, this.anchorsLength, 2]))
+      // ObjectnessScore
+      let BoxConfidence = tf.sigmoid(rawPrediction.slice([0, 0, 0, 4], [13, 13, this.anchorsLength, 1]))
+      // ClassProb
+      let BoxClassProbs = tf.softmax(rawPrediction.slice([0, 0, 0, 5], [13, 13, this.anchorsLength, this.classesLength]));
+
+      // from boxes with xy wh to x1,y1 x2,y2
+      // Mainly for NMS + rescaling
+      /*
+        x1 = x - (w/2)
+        y1 = y - (h/2)
+        x2 = x + (w/2)
+        y2 = y + (h/2)
+      */
+      // BoxScale
+      BoxXY = tf.div(tf.add(BoxXY, this.ConvIndex), this.ConvDims);
+
+      BoxWH = tf.div(tf.mul(BoxWH, this.AnchorsTensor), this.ConvDims);
+
+      const Div = tf.div(BoxWH, tf.scalar(2))
+
+      const BoxMins = tf.sub(BoxXY, Div);
+
+      const BoxMaxes = tf.add(BoxXY, Div);
+      const Size = [BoxMins.shape[0], BoxMins.shape[1], BoxMins.shape[2], 1];
+
+      // main box tensor
+      const boxes = tf.concat([BoxMins.slice([0, 0, 0, 1], Size),
+        BoxMins.slice([0, 0, 0, 0], Size),
+        BoxMaxes.slice([0, 0, 0, 1], Size),
+        BoxMaxes.slice([0, 0, 0, 0], Size)
+      ], 3)
+        .reshape([845, 4])
+
+
+      // Filter boxes by objectness threshold
+      // not filtering / getting a mask really
+
+      BoxConfidence = BoxConfidence.squeeze([3])
+      const ObjectnessMask = tf.greaterEqual(BoxConfidence, tf.scalar(this.filterBoxesThreshold))
+
+
+      // Filter boxes by class probability threshold
+      const BoxScores = tf.mul(BoxConfidence, tf.max(BoxClassProbs, 3));
+      const BoxClassProbMask = tf.greaterEqual(BoxScores, tf.scalar(this.classProbThreshold));
+
+      // getting classes indices
+      const Classes = tf.argMax(BoxClassProbs, -1)
+
+
+      // Final Mask: each elem that survived both filters (0x0, 0x1, 1x0 = fail) 1x1 = survived
+      const FinalMask = BoxClassProbMask.mul(ObjectnessMask)
+
+      const Indices = FinalMask.flatten().toInt().mul(this.indicesTensor)
+      return [boxes, BoxScores, Classes, Indices]
+    })
+
+    // we started at one in the range so we remove 1 now
+
+    let indicesArr = Array.from(await Indices.data()).filter(i => i > 0).map(i => i - 1);
+
+    if (indicesArr.length == 0) {
+      boxes.dispose()
+      BoxScores.dispose()
+      Classes.dispose()
+      return results
+    }
+    const indicesTensor = tf.tensor1d(indicesArr, "int32");
+    let filteredBoxes = boxes.gather(indicesTensor)
+    let filteredScores = BoxScores.flatten().gather(indicesTensor)
+    let filteredClasses = Classes.flatten().gather(indicesTensor)
+    boxes.dispose()
+    BoxScores.dispose()
+    Classes.dispose()
+    indicesTensor.dispose()
+
+    // Img Rescale
+    const Height = tf.scalar(this.imgHeight);
+    const Width = tf.scalar(this.imgWidth)
+    const ImageDims = tf.stack([Height, Width, Height, Width]).reshape([1, 4]);
+    filteredBoxes = filteredBoxes.mul(ImageDims)
+
+    // NonMaxSuppression
+    // GreedyNMS
+    const [boxArr, scoreArr, classesArr] = await Promise.all([filteredBoxes.data(), filteredScores.data(), filteredClasses.data()]);
+    filteredBoxes.dispose()
+    filteredScores.dispose()
+    filteredClasses.dispose()
+
+    let zipped = [];
+    for (let i = 0; i < scoreArr.length; i++) {
+      // [Score, [box corner coords], classIndex]
+      zipped.push([scoreArr[i], [boxArr[4 * i], boxArr[4 * i + 1], boxArr[4 * i + 2], boxArr[4 * i + 3]], classesArr[i]]);
+    }

-      if (callback) {
-        callback(results);
+    // Sort by descending order of scores (first index of zipped array)
+    const sorted = zipped.sort((a, b) => b[0] - a[0]);
+    const selectedBoxes = []
+    // Greedily go through boxes in descending score order and only
+    // return boxes that are below the IoU threshold.
+    sorted.forEach(box => {
+      let Push = true;
+      for (let i = 0; i < selectedBoxes.length; i++) {
+        // Compare IoU of zipped[1], since that is the box coordinates arr
+        let w = Math.min(box[1][3], selectedBoxes[i][1][3]) - Math.max(box[1][1], selectedBoxes[i][1][1]);
+        let h = Math.min(box[1][2], selectedBoxes[i][1][2]) - Math.max(box[1][0], selectedBoxes[i][1][0]);
+        let Intersection = w < 0 || h < 0 ? 0 : w * h
+        let Union = (box[1][3] - box[1][1]) * (box[1][2] - box[1][0]) + (selectedBoxes[i][1][3] - selectedBoxes[i][1][1]) * (selectedBoxes[i][1][2] - selectedBoxes[i][1][0]) - Intersection
+        let Iou = Intersection / Union
+        if (Iou > this.IOUThreshold) {
+          Push = false;
+          break;
+        }
      }
+      if (Push) selectedBoxes.push(box);
+    });

-      return results;
-    }
-    console.warn('Model has not finished loading');
-    return false;
-  }
-}
+    // final phase

-const YOLO = (videoOrOptionsOrCallback, optionsOrCallback, cb = () => {}) => {
-  let callback = cb;
-  let options = {};
-  const video = videoOrOptionsOrCallback;
+    // add any output you want
+    for (let id = 0; id < selectedBoxes.length; id++) {

-  if (typeof videoOrOptionsOrCallback === 'object') {
-    options = videoOrOptionsOrCallback;
-  } else if (typeof videoOrOptionsOrCallback === 'function') {
-    callback = videoOrOptionsOrCallback;
-  }
+      const classProb = selectedBoxes[id][0];
+      const classProbRounded = Math.round(classProb * 1000) / 10
+      const className = this.classNames[selectedBoxes[id][2]];
+      const classIndex = selectedBoxes[id][2];
+      const [x1, y1, x2, y2] = selectedBoxes[id][1];
+      // Need to get this out
+      // TODO : add a hsla color for later visualization
+      const resultObj = { id, className, classIndex, classProb, classProbRounded, x1, y1, x2, y2 };
+      results.detections.push(resultObj);
+    }
+    // Misc
+    results.totalDetections = results.detections.length;
+    results.scaleX = this.scaleX
+    results.scaleY = this.scaleY

-  if (typeof optionsOrCallback === 'object') {
-    options = optionsOrCallback;
-  } else if (typeof optionsOrCallback === 'function') {
-    callback = optionsOrCallback;
+    return results
  }

-  return new YOLOBase(video, options, callback);
-};
+}

export default YOLO;
-
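Note on the hard-coded shapes above: with the 5 anchors defined in the constructor and the 80 COCO class names, each cell of the 13 × 13 grid predicts 5 × (80 + 5) = 425 values, which is the 13x13x425 output the init() comment assumes, and there are 13 × 13 × 5 = 845 candidate boxes, which is why the indices tensor is tf.range(1, 846) and the box tensor is reshaped to [845, 4].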
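A minimal usage sketch of the class this commit introduces, assuming the module's default export is imported from src/YOLO; the import path, element wiring, and option values here are illustrative, not part of the commit:

import YOLO from './src/YOLO';

// Illustrative consumer code: construct with the same option names DEFAULTS uses,
// load the model, optionally warm it up, then run detection on an image element.
const yolo = new YOLO({
  filterBoxesThreshold: 0.01,
  IOUThreshold: 0.4,
  classProbThreshold: 0.4,
});

async function run(imgElement) {
  const ok = await yolo.loadModel(); // resolves to true on success, false on failure
  if (!ok) return;
  yolo.cache(); // optional warm-up; the comment above says to call it after loadModel()
  const results = await yolo.detect(imgElement);
  // results: { totalDetections, detections: [{ id, className, classIndex, classProb,
  //            classProbRounded, x1, y1, x2, y2 }], scaleX, scaleY }
  console.log(results.totalDetections, results.detections);
}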