Skip to content

Commit 9375570

Browse files
committed
refactor ObjectDetector Class to be main interface for models
1 parent ce82f51 commit 9375570

File tree

4 files changed

+140
-142
lines changed

4 files changed

+140
-142
lines changed

examples/objectDetection/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
<meta charset="UTF-8" />
55
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
66
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
7-
<title>ml5.js objectDetection Webcam Example</title>
7+
<title>ml5.js objectDetector Webcam Example</title>
88
<script src="https://cdn.jsdelivr.net/npm/[email protected]/lib/p5.js"></script>
99
<script src="../../dist/ml5.js"></script>
1010
</head>

examples/objectDetection/sketch.js

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,35 +21,32 @@ function setup() {
2121
createCanvas(640, 480);
2222

2323
video = createCapture(VIDEO);
24-
video.size(640, 480);
24+
video.size(width, height);
2525
video.hide();
2626

27-
// detector.detect(video, gotDetections);
2827
detector.detectStart(video, gotDetections);
2928
}
3029

3130
function gotDetections(results) {
3231
detections = results;
33-
// console.log(results);
34-
// detector.detect(video, gotDetections);
3532
}
3633

3734
function draw() {
3835
image(video, 0, 0);
3936

4037
for (let i = 0; i < detections.length; i += 1) {
41-
const object = detections[i];
38+
const detection = detections[i];
4239

4340
// draw bounding box
4441
stroke(0, 255, 0);
4542
strokeWeight(4);
4643
noFill();
47-
rect(object.x, object.y, object.width, object.height);
44+
rect(detection.x, detection.y, detection.width, detection.height);
4845

4946
// draw label
5047
noStroke();
5148
fill(255);
5249
textSize(24);
53-
text(object.label, object.x + 10, object.y + 24);
50+
text(detection.label, detection.x + 10, detection.y + 24);
5451
}
5552
}

src/ObjectDetector/cocossd.js

Lines changed: 17 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -4,86 +4,46 @@
44
// https://opensource.org/licenses/MIT
55

66
/*
7-
COCO-SSD Object detection
7+
COCO-SSD Object detection model
88
Wraps the coco-ssd model in tfjs to be used in ml5
99
*/
1010
import * as tf from "@tensorflow/tfjs";
1111
import * as cocoSsd from "@tensorflow-models/coco-ssd";
12-
import callCallback from "../utils/callcallback";
13-
import handleArguments from "../utils/handleArguments";
1412
import { mediaReady } from "../utils/imageUtilities";
1513

1614
const DEFAULTS = {
1715
base: "lite_mobilenet_v2",
1816
modelUrl: undefined,
1917
};
2018

21-
export class CocoSsdBase {
22-
/**
23-
* Create CocoSsd model. Works on video and images.
24-
* @param {function} constructorCallback - Optional. A callback function that is called once the model has loaded. If no callback is provided, it will return a promise
25-
* that will be resolved once the model has loaded.
26-
*/
27-
constructor(video, options, constructorCallback) {
28-
this.video = video || null;
29-
this.modelReady = false;
30-
31-
this.isPredicting = false;
32-
this.signalStop = false;
33-
this.prevCall = "";
34-
19+
export class CocoSsd {
20+
constructor(options = {}) {
21+
this.model = null;
3522
this.config = {
3623
base: options.base || DEFAULTS.base,
3724
modelUrl: options.modelUrl || DEFAULTS.modelUrl,
3825
};
39-
this.callback = constructorCallback;
40-
41-
this.ready = callCallback(this.loadModel(), this.callback);
4226
}
4327

44-
async loadModel() {
28+
async load() {
4529
await tf.setBackend("webgl"); // this line resolves warning : performance is poor on webgpu backend
4630
await tf.ready();
4731

4832
this.model = await cocoSsd.load(this.config);
49-
50-
this.modelReady = true;
5133
return this;
5234
}
5335

5436
/**
55-
* @typedef {Object} ObjectDetectorPrediction
56-
* @property {number} x - top left x coordinate of the prediction box in pixels.
57-
* @property {number} y - top left y coordinate of the prediction box in pixels.
58-
* @property {number} width - width of the prediction box in pixels.
59-
* @property {number} height - height of the prediction box in pixels.
60-
* @property {string} label - the label given.
61-
* @property {number} confidence - the confidence score (0 to 1).
62-
* @property {ObjectDetectorPredictionNormalized} normalized - a normalized object of the predicition
37+
* Detect objects that are in the image/video/canvas
38+
* @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData} imgToPredict - Subject of the detection.
39+
* @returns {Array} Array of detections
6340
*/
64-
65-
/**
66-
* @typedef {Object} ObjectDetectorPredictionNormalized
67-
* @property {number} x - top left x coordinate of the prediction box (0 to 1).
68-
* @property {number} y - top left y coordinate of the prediction box (0 to 1).
69-
* @property {number} width - width of the prediction box (0 to 1).
70-
* @property {number} height - height of the prediction box (0 to 1).
71-
*/
72-
/**
73-
* Detect objects that are in video, returns bounding box, label, and confidence scores
74-
* @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData} subject - Subject of the detection.
75-
* @returns {ObjectDetectorPrediction}
76-
*/
77-
async detectInternal(imgToPredict) {
78-
// this.isPredicting = true;
79-
await this.ready;
80-
41+
async detect(imgToPredict) {
8142
mediaReady(imgToPredict, true);
82-
8343
await tf.nextFrame();
8444

85-
const predictions = await this.model.detect(imgToPredict);
86-
const formattedPredictions = predictions.map(prediction => {
45+
const detections = await this.model.detect(imgToPredict);
46+
const formattedDetections = detections.map(prediction => {
8747
return {
8848
label: prediction.class,
8949
confidence: prediction.score,
@@ -100,66 +60,12 @@ export class CocoSsdBase {
10060
};
10161
});
10262

103-
this.isPredicting = false;
104-
105-
return formattedPredictions;
106-
}
107-
108-
/**
109-
* Detect objects that are in video, returns bounding box, label, and confidence scores
110-
* @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData} subject - Subject of the detection.
111-
* @param {function} callback - Optional. A callback function that is called once the model has loaded. If no callback is provided, it will return a promise
112-
* that will be resolved once the prediction is done.
113-
* @returns {ObjectDetectorPrediction}
114-
*/
115-
async detect(inputOrCallback, cb) {
116-
const args = handleArguments(this.video, inputOrCallback, cb);
117-
args.require("image", "Detection subject not supported");
118-
119-
return callCallback(this.detectInternal(args.image), args.callback);
120-
}
121-
122-
async detectStart(inputNumOrCallback, numOrCallback, cb){
123-
const { image, number, callback } = handleArguments(
124-
inputNumOrCallback,
125-
numOrCallback,
126-
cb
127-
).require("image", "No input provided.");
128-
129-
const detectFrame = async () => {
130-
await mediaReady(image, true);
131-
132-
await callCallback(this.detectInternal(image), callback);
133-
134-
if(!this.signalStop){
135-
requestAnimationFrame(detectFrame);
136-
} else {
137-
this.isPredicting = false;
138-
}
139-
};
140-
141-
// start the detection
142-
this.signalStop = false;
143-
if (!this.isPredicting) {
144-
this.isPredicting = true;
145-
detectFrame();
146-
}
147-
148-
if (this.prevCall === "start") {
149-
console.warn("warning");
150-
}
151-
this.prevCall = "start";
152-
}
153-
154-
detectStop(){
155-
if (this.isPredicting) { this.signalStop = true; }
156-
this.prevCall = "stop";
63+
return formattedDetections;
15764
}
15865
}
15966

160-
export const CocoSsd = (...inputs) => {
161-
const { video, options = {}, callback } = handleArguments(...inputs);
162-
return new CocoSsdBase(video, options, callback);
163-
};
164-
165-
export default CocoSsd;
67+
export async function load(modelConfig = {}) {
68+
const cocoSsdInstance = new CocoSsd(modelConfig);
69+
await cocoSsdInstance.load();
70+
return cocoSsdInstance;
71+
}

src/ObjectDetector/index.js

Lines changed: 118 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@
77
ObjectDetection
88
*/
99

10+
import * as cocoSsd from "./cocossd.js";
11+
import { handleModelName } from "../utils/handleOptions";
1012
import handleArguments from "../utils/handleArguments";
11-
import {CocoSsd} from "./cocossd.js";
13+
import callCallback from "../utils/callcallback";
14+
import { mediaReady } from "../utils/imageUtilities";
15+
16+
const MODEL_OPTIONS = ["cocossd"]; // Expandable for other models like YOLO
1217

1318
class ObjectDetector {
1419
/**
@@ -24,38 +29,128 @@ class ObjectDetector {
2429
* @param {Object} options - Optional. A set of options.
2530
* @param {function} callback - Optional. A callback function that is called once the model has loaded.
2631
*/
27-
constructor(modelNameOrUrl, video, options, callback) {
28-
this.video = video || null;
29-
this.modelNameOrUrl = modelNameOrUrl;
30-
this.options = options || {};
31-
this.callback = callback;
32+
constructor(modelNameOrUrl, options = {}, callback) {
33+
this.model = null;
34+
this.modelName = null;
35+
this.modelToUse = null;
36+
37+
// flags for detectStart() and detectStop()
38+
this.isDetecting = false;
39+
this.signalStop = false;
40+
this.prevCall = "";
41+
42+
this.modelName = handleModelName(
43+
modelNameOrUrl,
44+
MODEL_OPTIONS,
45+
"cocossd",
46+
"objectDetector"
47+
);
48+
3249

33-
switch (modelNameOrUrl) {
50+
switch (this.modelName) {
3451
case "cocossd":
35-
this.model = CocoSsd(this.video, this.options, callback);
36-
return this;
52+
this.modelToUse = cocoSsd;
53+
break;
54+
case "yolo":
55+
this.modelToUse = yolo;
56+
break;
57+
// more models... currently only cocossd is supported
3758
default:
38-
// use cocossd as default
39-
this.model = CocoSsd(this.video, this.options, callback);
40-
return this;
59+
console.warn(`Unknown model: ${this.modelName}, defaulting to CocoSsd`);
60+
this.modelToUse = cocoSsd;
4161
}
62+
63+
// load model and assign ready promise
64+
this.ready = callCallback(this.loadModel(options), callback);
4265
}
43-
}
4466

45-
const objectDetector = (...inputs) => {
46-
const { video, options = {}, callback, string } = handleArguments(...inputs)
47-
.require('string', 'Please specify a model to use. E.g: "YOLO"');
67+
async loadModel(options) {
68+
if (!this.modelToUse || !this.modelToUse.load) {
69+
throw new Error(`Model loader is missing or invalid for: ${this.modelName}`);
70+
}
71+
72+
this.model = await this.modelToUse.load(options);
73+
74+
return this;
75+
}
76+
77+
/**
78+
* @typedef {Object} ObjectDetectorPrediction
79+
* @property {number} x - top left x coordinate of the prediction box in pixels.
80+
* @property {number} y - top left y coordinate of the prediction box in pixels.
81+
* @property {number} width - width of the prediction box in pixels.
82+
* @property {number} height - height of the prediction box in pixels.
83+
* @property {string} label - the label given.
84+
* @property {number} confidence - the confidence score (0 to 1).
85+
* @property {ObjectDetectorPredictionNormalized} normalized - the prediction box expressed in normalized (0 to 1) coordinates
86+
*/
87+
88+
/**
89+
* @typedef {Object} ObjectDetectorPredictionNormalized
90+
* @property {number} x - top left x coordinate of the prediction box (0 to 1).
91+
* @property {number} y - top left y coordinate of the prediction box (0 to 1).
92+
* @property {number} width - width of the prediction box (0 to 1).
93+
* @property {number} height - height of the prediction box (0 to 1).
94+
*/
95+
96+
/**
97+
* Detect objects once from the input image/video/canvas.
98+
* @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData} input - Target element.
99+
* @param {function} cb - Optional callback.
100+
* @returns {ObjectDetectorPrediction}
101+
*/
102+
async detect(input, cb) {
103+
const args = handleArguments(input, cb).require("image", "No valid image input.");
104+
await this.ready;
105+
return callCallback(this.model.detect(args.image), args.callback);
106+
}
107+
108+
/**
109+
* Start continuous detection on video/canvas input
110+
* @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData} input - Target element.
111+
* @param {function} callback - Callback function called with each detection result.
112+
*/
113+
async detectStart(input, callback) {
114+
const args = handleArguments(input, callback).require("image", "No input provided.");
115+
116+
const detectFrame = async () => {
117+
await mediaReady(args.image, true);
118+
await callCallback(this.model.detect(args.image), args.callback);
48119

49-
let model = string;
50-
// TODO: I think we should delete this.
51-
if (model.indexOf("http") === -1) {
52-
model = model.toLowerCase();
120+
if (!this.signalStop) {
121+
requestAnimationFrame(detectFrame);
122+
} else {
123+
this.isDetecting = false;
124+
}
125+
};
126+
127+
this.signalStop = false;
128+
if (!this.isDetecting) {
129+
this.isDetecting = true;
130+
detectFrame();
131+
}
132+
133+
if (this.prevCall === "start") {
134+
console.warn(
135+
"detectStart() called again without detectStop(). Only the latest call is running."
136+
);
137+
}
138+
139+
this.prevCall = "start";
53140
}
54141

55-
const instance = new ObjectDetector(model, video, options, callback);
142+
detectStop() {
143+
if (this.isDetecting) {
144+
this.signalStop = true;
145+
}
146+
this.prevCall = "stop";
147+
}
148+
}
56149

57-
// return instance.model.callback ? instance.model : instance.model.ready;
58-
return instance.model;
150+
const objectDetector = (modelNameOrUrl, optionsOrCallback, cb) => {
151+
const { string, options = {}, callback } = handleArguments(modelNameOrUrl, optionsOrCallback, cb);
152+
const instance = new ObjectDetector(string, options, callback);
153+
return instance;
59154
};
60155

61156
export default objectDetector;

0 commit comments

Comments
 (0)