Skip to content

Commit 68f6f72

Browse files
committed
woops
1 parent 3a158a2 commit 68f6f72

File tree

1 file changed

+157
-105
lines changed

1 file changed

+157
-105
lines changed

src/YOLO/index.js

Lines changed: 157 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@ const DEFAULTS = {
1616
filterBoxesThreshold: 0.01,
1717
IOUThreshold: 0.4,
1818
classProbThreshold: 0.4,
19-
URL = "https://raw.githubusercontent.com/ml5js/ml5-library/master/src/YOLO/model.json",
19+
URL: "https://raw.githubusercontent.com/ml5js/ml5-library/master/src/YOLO/model.json"
2020
};
2121

2222
class YOLO {
23-
2423
constructor(options) {
25-
this.filterBoxesThreshold = options.filterBoxesThreshold || DEFAULTS.filterBoxesThreshold;
24+
this.filterBoxesThreshold =
25+
options.filterBoxesThreshold || DEFAULTS.filterBoxesThreshold;
2626
this.IOUThreshold = options.IOUThreshold || DEFAULTS.IOUThreshold;
27-
this.classProbThreshold = options.classProbThreshold || DEFAULTS.classProbThreshold;
27+
this.classProbThreshold =
28+
options.classProbThreshold || DEFAULTS.classProbThreshold;
2829
this.modelURL = options.url || DEFAULTS.URL;
2930
this.model = null;
3031
this.inputWidth = 416;
@@ -40,19 +41,18 @@ class YOLO {
4041
this.scaleX;
4142
this.scaleY;
4243
this.anchorsLength = this.anchors.length;
43-
this.classesLength = this.Params.classNames.length;
44+
this.classesLength = this.classNames.length;
4445
this.init();
45-
4646
}
4747

4848
init() {
49-
// indices tensor to filter the elements later on
49+
// indices tensor to filter the elements later on
5050
this.indicesTensor = tf.range(1, 846, 1, "int32");
5151

5252
// Grid To Split the raw predictions : Assumes Our Model output is 1 Tensor with 13x13x425
5353
// gonna hard code all this stuff see if it works
54-
// this can be done once at the initial phase
55-
// TODO : make this more modular
54+
// this can be done once at the initial phase
55+
// TODO : make this more modular
5656

5757
[this.ConvIndex, this.ConvDims, this.AnchorsTensor] = tf.tidy(() => {
5858
let ConvIndex = tf.range(0, 13);
@@ -73,16 +73,16 @@ class YOLO {
7373
});
7474
}
7575

76-
// takes HTMLCanvasElement || HTMLImageElement ||HTMLVideoElement || ImageData as input
76+
// takes HTMLCanvasElement || HTMLImageElement ||HTMLVideoElement || ImageData as input
7777
// outs results obj
7878
async detect(input) {
7979
const predictions = tf.tidy(() => {
8080
const data = this.preProccess(input);
8181
const preds = this.model.predict(data);
8282
return preds;
83-
})
83+
});
8484
const results = await this.postProccess(predictions);
85-
return results
85+
return results;
8686
}
8787

8888
async loadModel() {
@@ -95,54 +95,65 @@ class YOLO {
9595
}
9696
}
9797

98-
99-
//does not dispose of the model atm
98+
//does not dispose of the model atm
10099
dispose() {
101100
tf.disposeconstiables();
102101
}
103102

104103
// should be called after loadModel()
105104
cache() {
106105
tf.tidy(() => {
107-
const dummy = tf.zeros([0, 416, 416, 3])
108-
const data = this.model.predict(dummy)
109-
})
106+
const dummy = tf.zeros([0, 416, 416, 3]);
107+
const data = this.model.predict(dummy);
108+
});
110109
}
111110

112-
113-
114111
preProccess(input) {
115-
let img = tf.fromPixels(input)
112+
let img = tf.fromPixels(input);
116113
this.imgWidth = img.shape[1];
117114
this.imgHeight = img.shape[0];
118-
img = tf.image.resizeBilinear(img, [this.inputHeight, this.inputWidth])
115+
img = tf.image
116+
.resizeBilinear(img, [this.inputHeight, this.inputWidth])
119117
.toFloat()
120118
.div(tf.scalar(255))
121119
.expandDims(0);
122120
//Scale Stuff
123121
this.scaleX = this.imgHeight / this.inputHeight;
124122
this.scaleY = this.imgWidth / this.inputWidth;
125-
return img
123+
return img;
126124
}
127125

128-
129126
async postProccess(rawPrediction) {
130-
131-
let results = { totalDetections: 0, detections: [] }
132-
133-
const [boxes, BoxScores, Classes, Indices] = tf.tidy(() => {
134-
135-
rawPrediction = tf.reshape(rawPrediction, [13, 13, this.anchorsLength, this.classesLength + 5]);
127+
let results = { totalDetections: 0, detections: [] };
128+
129+
const [boxes, boxScores, classes, Indices] = tf.tidy(() => {
130+
rawPrediction = tf.reshape(rawPrediction, [
131+
13,
132+
13,
133+
this.anchorsLength,
134+
this.classesLength + 5
135+
]);
136136
// Box Coords
137-
let BoxXY = tf.sigmoid(rawPrediction.slice([0, 0, 0, 0], [13, 13, this.anchorsLength, 2]))
138-
let BoxWH = tf.exp(rawPrediction.slice([0, 0, 0, 2], [13, 13, this.anchorsLength, 2]))
137+
let BoxXY = tf.sigmoid(
138+
rawPrediction.slice([0, 0, 0, 0], [13, 13, this.anchorsLength, 2])
139+
);
140+
let BoxWH = tf.exp(
141+
rawPrediction.slice([0, 0, 0, 2], [13, 13, this.anchorsLength, 2])
142+
);
139143
// ObjectnessScore
140-
let BoxConfidence = tf.sigmoid(rawPrediction.slice([0, 0, 0, 4], [13, 13, this.anchorsLength, 1]))
144+
let BoxConfidence = tf.sigmoid(
145+
rawPrediction.slice([0, 0, 0, 4], [13, 13, this.anchorsLength, 1])
146+
);
141147
// ClassProb
142-
let BoxClassProbs = tf.softmax(rawPrediction.slice([0, 0, 0, 5], [13, 13, this.anchorsLength, this.classesLength]));
143-
148+
let BoxClassProbs = tf.softmax(
149+
rawPrediction.slice(
150+
[0, 0, 0, 5],
151+
[13, 13, this.anchorsLength, this.classesLength]
152+
)
153+
);
154+
144155
// from boxes with xy wh to x1,y1 x2,y2
145-
// Mainly for NMS + rescaling
156+
// Mainly for NMS + rescaling
146157
/*
147158
x1 = x + (h/2)
148159
y1 = y - (w/2)
@@ -151,99 +162,132 @@ class YOLO {
151162
*/
152163
// BoxScale
153164
BoxXY = tf.div(tf.add(BoxXY, this.ConvIndex), this.ConvDims);
154-
155-
BoxWH = tf.div(tf.mul(BoxWH, this.AnchorsTensor), this.ConvDims);
156-
157-
const Div = tf.div(BoxWH, tf.scalar(2))
158-
159-
const BoxMins = tf.sub(BoxXY, Div);
160-
161-
const BoxMaxes = tf.add(BoxXY, Div);
162-
const Size = [BoxMins.shape[0], BoxMins.shape[1], BoxMins.shape[2], 1];
163-
164-
// main box tensor
165-
const boxes = tf.concat([BoxMins.slice([0, 0, 0, 1], Size),
166-
BoxMins.slice([0, 0, 0, 0], Size),
167-
BoxMaxes.slice([0, 0, 0, 1], Size),
168-
BoxMaxes.slice([0, 0, 0, 0], Size)
169-
], 3)
170-
.reshape([845, 4])
171-
172165

173-
// Filterboxes by objectness threshold
174-
// not filtering / getting a mask really
175-
176-
BoxConfidence = BoxConfidence.squeeze([3])
177-
const ObjectnessMask = tf.greaterEqual(BoxConfidence, tf.scalar(this.filterboxesThreshold))
166+
BoxWH = tf.div(tf.mul(BoxWH, this.AnchorsTensor), this.ConvDims);
178167

168+
const Div = tf.div(BoxWH, tf.scalar(2));
179169

180-
// Filterboxes by class probability threshold
181-
const BoxScores = tf.mul(BoxConfidence, tf.max(BoxClassProbs, 3));
182-
const BoxClassProbMask = tf.greaterEqual(BoxScores, tf.scalar(this.classProbThreshold));
170+
const BoxMins = tf.sub(BoxXY, Div);
183171

184-
// getting classes indices
185-
const Classes = tf.argMax(BoxClassProbs, -1)
172+
const BoxMaxes = tf.add(BoxXY, Div);
173+
const Size = [BoxMins.shape[0], BoxMins.shape[1], BoxMins.shape[2], 1];
186174

175+
// main box tensor
176+
const boxes = tf
177+
.concat(
178+
[
179+
BoxMins.slice([0, 0, 0, 1], Size),
180+
BoxMins.slice([0, 0, 0, 0], Size),
181+
BoxMaxes.slice([0, 0, 0, 1], Size),
182+
BoxMaxes.slice([0, 0, 0, 0], Size)
183+
],
184+
3
185+
)
186+
.reshape([845, 4]);
187+
188+
// Filterboxes by objectness threshold
189+
// not filtering / getting a mask really
190+
191+
BoxConfidence = BoxConfidence.squeeze([3]);
192+
const ObjectnessMask = tf.greaterEqual(
193+
BoxConfidence,
194+
tf.scalar(this.filterboxesThreshold)
195+
);
196+
197+
// Filterboxes by class probability threshold
198+
const boxScores = tf.mul(BoxConfidence, tf.max(BoxClassProbs, 3));
199+
const BoxClassProbMask = tf.greaterEqual(
200+
boxScores,
201+
tf.scalar(this.classProbThreshold)
202+
);
203+
204+
// getting classes indices
205+
const classes = tf.argMax(BoxClassProbs, -1);
187206

188207
// Final Mask each elem that survived both filters (0x0 0x1 1x0 = fail ) 1x1 = survived
189-
const FinalMask = BoxClassProbMask.mul(ObjectnessMask)
208+
const FinalMask = BoxClassProbMask.mul(ObjectnessMask);
190209

191-
const Indices = FinalMask.flatten().toInt().mul(this.IndicesTensor)
192-
return [boxes, BoxScores, Classes, Indices]
193-
})
210+
const Indices = FinalMask.flatten()
211+
.toInt()
212+
.mul(this.IndicesTensor);
213+
return [boxes, boxScores, classes, Indices];
214+
});
194215

195-
//we started at one in the range so we remove 1 now
216+
//we started at one in the range so we remove 1 now
196217

197-
let indicesArr = Array.from(await Indices.data()).filter(i => i > 0).map(i => i - 1);
218+
let indicesArr = Array.from(await Indices.data())
219+
.filter(i => i > 0)
220+
.map(i => i - 1);
198221

199222
if (indicesArr.length == 0) {
200-
boxes.dispose()
201-
BoxScores.dispose()
202-
Classes.dispose()
203-
return results
223+
boxes.dispose();
224+
boxScores.dispose();
225+
classes.dispose();
226+
return results;
204227
}
205228
const indicesTensor = tf.tensor1d(indicesArr, "int32");
206-
let filteredBoxes = boxes.gather(indicesTensor)
207-
let filteredScores = BoxScores.flatten().gather(indicesTensor)
208-
let filteredClasses = Classes.flatten().gather(indicesTensor)
209-
boxes.dispose()
210-
BoxScores.dispose()
211-
Classes.dispose()
212-
indicesTensor.dispose()
229+
let filteredBoxes = boxes.gather(indicesTensor);
230+
let filteredScores = boxScores.flatten().gather(indicesTensor);
231+
let filteredclasses = classes.flatten().gather(indicesTensor);
232+
boxes.dispose();
233+
boxScores.dispose();
234+
classes.dispose();
235+
indicesTensor.dispose();
213236

214237
//Img Rescale
215238
const Height = tf.scalar(this.imgHeight);
216-
const Width = tf.scalar(this.imgWidth)
239+
const Width = tf.scalar(this.imgWidth);
217240
const ImageDims = tf.stack([Height, Width, Height, Width]).reshape([1, 4]);
218-
filteredBoxes = filteredBoxes.mul(ImageDims)
241+
filteredBoxes = filteredBoxes.mul(ImageDims);
219242

220-
// NonMaxSuppression
243+
// NonMaxSuppression
221244
// GreedyNMS
222-
const [boxArr, scoreArr, classesArr] = await Promise.all([filteredBoxes.data(), filteredScores.data(), filteredClasses.data()]);
223-
filteredBoxes.dispose()
224-
filteredScores.dispose()
225-
filteredClasses.dispose()
245+
const [boxArr, scoreArr, classesArr] = await Promise.all([
246+
filteredBoxes.data(),
247+
filteredScores.data(),
248+
filteredclasses.data()
249+
]);
250+
filteredBoxes.dispose();
251+
filteredScores.dispose();
252+
filteredclasses.dispose();
226253

227254
let zipped = [];
228255
for (let i = 0; i < scoreArr.length; i++) {
229256
// [Score,x,y,w,h,classindex]
230-
zipped.push([scoreArr[i], [boxArr[4 * i], boxArr[4 * i + 1], boxArr[4 * i + 2], boxArr[4 * i + 3]], classesArr[i]]);
257+
zipped.push([
258+
scoreArr[i],
259+
[
260+
boxArr[4 * i],
261+
boxArr[4 * i + 1],
262+
boxArr[4 * i + 2],
263+
boxArr[4 * i + 3]
264+
],
265+
classesArr[i]
266+
]);
231267
}
232268

233269
// Sort by descending order of scores (first index of zipped array)
234270
const sorted = zipped.sort((a, b) => b[0] - a[0]);
235-
const selectedBoxes = []
271+
const selectedBoxes = [];
236272
// Greedily go through boxes in descending score order and only
237273
// return boxes that are below the IoU threshold.
238274
sorted.forEach(box => {
239275
let Push = true;
240276
for (let i = 0; i < selectedBoxes.length; i++) {
241277
// Compare IoU of zipped[1], since that is the box coordinates arr
242-
let w = Math.min(box[1][3], selectedBoxes[i][1][3]) - Math.max(box[1][1], selectedBoxes[i][1][1]);
243-
let h = Math.min(box[1][2], selectedBoxes[i][1][2]) - Math.max(box[1][0], selectedBoxes[i][1][0]);
244-
let Intersection = w < 0 || h < 0 ? 0 : w * h
245-
let Union = (box[1][3] - box[1][1]) * (box[1][2] - box[1][0]) + (selectedBoxes[i][1][3] - selectedBoxes[i][1][1]) * (selectedBoxes[i][1][2] - selectedBoxes[i][1][0]) - Intersection
246-
let Iou = Intersection / Union
278+
let w =
279+
Math.min(box[1][3], selectedBoxes[i][1][3]) -
280+
Math.max(box[1][1], selectedBoxes[i][1][1]);
281+
let h =
282+
Math.min(box[1][2], selectedBoxes[i][1][2]) -
283+
Math.max(box[1][0], selectedBoxes[i][1][0]);
284+
let Intersection = w < 0 || h < 0 ? 0 : w * h;
285+
let Union =
286+
(box[1][3] - box[1][1]) * (box[1][2] - box[1][0]) +
287+
(selectedBoxes[i][1][3] - selectedBoxes[i][1][1]) *
288+
(selectedBoxes[i][1][2] - selectedBoxes[i][1][0]) -
289+
Intersection;
290+
let Iou = Intersection / Union;
247291
if (Iou > this.IOUThreshold) {
248292
Push = false;
249293
break;
@@ -252,29 +296,37 @@ class YOLO {
252296
if (Push) selectedBoxes.push(box);
253297
});
254298

255-
// final phase
299+
// final phase
256300

257301
// add any output you want
258302
for (let id = 0; id < selectedBoxes.length; id++) {
259-
260303
const classProb = selectedBoxes[id][0];
261-
const classProbRounded = Math.round(classProb * 1000) / 10
304+
const classProbRounded = Math.round(classProb * 1000) / 10;
262305
const className = this.classNames[selectedBoxes[id][2]];
263306
const classIndex = selectedBoxes[id][2];
264307
const [x1, y1, x2, y2] = selectedBoxes[id][1];
265-
// Need to get this out
266-
// TODO : add a hsla color for later visualization
267-
const resultObj = { id, className, classIndex, classProb, classProbRounded, x1, y1, x2, y2 };
308+
// Need to get this out
309+
// TODO : add a hsla color for later visualization
310+
const resultObj = {
311+
id,
312+
className,
313+
classIndex,
314+
classProb,
315+
classProbRounded,
316+
x1,
317+
y1,
318+
x2,
319+
y2
320+
};
268321
results.detections.push(resultObj);
269322
}
270-
// Misc
323+
// Misc
271324
results.totalDetections = results.detections.length;
272-
results.scaleX = this.scaleX
273-
results.scaleY = this.scaleY
325+
results.scaleX = this.scaleX;
326+
results.scaleY = this.scaleY;
274327

275-
return results
328+
return results;
276329
}
277-
278330
}
279331

280332
export default YOLO;

0 commit comments

Comments
 (0)