
Commit 176a2ad

add MoveNet pose detection
Updated the pose detection model from PoseNet to MoveNet.
1 parent 7f4d1bc commit 176a2ad

File tree

5 files changed: +108 -166 lines changed


examples/PoseDetection/index.html

Lines changed: 2 additions & 9 deletions
@@ -9,18 +9,11 @@
   <head>
     <meta charset="UTF-8" />
     <title>PoseNet example using p5.js</title>
-
-    <script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.6.0/p5.min.js"></script>
-    <script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.6.0/addons/p5.dom.min.js"></script>
-    <script
-      src="https://unpkg.com/[email protected]/dist/ml5.min.js"
-      type="text/javascript"
-    ></script>
-    <link rel="stylesheet" type="text/css" href="style.css" />
+    <script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.6.0/p5.js"></script>
+    <script src="../../dist/ml5.js"></script>
   </head>

   <body>
-    <p id="status">Loading model...</p>
     <script src="sketch.js"></script>
   </body>
 </html>

examples/PoseDetection/sketch.js

Lines changed: 7 additions & 7 deletions
@@ -5,7 +5,7 @@

 /* ===
 ml5 Example
-PoseNet example using p5.js
+poseDetection example using p5.js
 === */

 let video;
@@ -18,7 +18,7 @@ function setup() {
   video.size(width, height);

   // Create a new poseNet method with a single detection
-  poseNet = ml5.poseNet(video, modelReady);
+  poseNet = ml5.poseDetection(video, modelReady);
   // This sets up an event that fills the global variable "poses"
   // with an array every time new poses are detected
   poseNet.on("pose", function (results) {
@@ -29,31 +29,31 @@ function setup() {
 }

 function modelReady() {
-  select("#status").html("Model Loaded");
+  console.log("Model Loaded!");
 }

 function draw() {
   image(video, 0, 0, width, height);
-
+  //console.log(poses);
   // We can call both functions to draw all keypoints and the skeletons
   drawKeypoints();
-  drawSkeleton();
+  //drawSkeleton();
 }

 // A function to draw ellipses over the detected keypoints
 function drawKeypoints() {
   // Loop through all the poses detected
   for (let i = 0; i < poses.length; i++) {
     // For each pose detected, loop through all the keypoints
-    let pose = poses[i].pose;
+    let pose = poses[i];
     for (let j = 0; j < pose.keypoints.length; j++) {
       // A keypoint is an object describing a body part (like rightArm or leftShoulder)
       let keypoint = pose.keypoints[j];
       // Only draw an ellipse is the pose probability is bigger than 0.2
       if (keypoint.score > 0.2) {
         fill(255, 0, 0);
         noStroke();
-        ellipse(keypoint.position.x, keypoint.position.y, 10, 10);
+        ellipse(keypoint.x, keypoint.y, 10, 10);
       }
     }
   }
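Note on the sketch.js changes above: ml5.poseDetection now emits the array returned by the TF.js detector directly, so each element of poses is itself a pose object (no .pose wrapper and no precomputed skeleton), and each keypoint exposes x, y, and score at the top level instead of a position object. A minimal sketch of reading one pose under that assumption (the name field on keypoints comes from the TF.js pose-detection output format, not from this diff):

poseNet.on("pose", function (poses) {
  if (poses.length === 0) return;
  // Each pose is the raw MoveNet result: { keypoints: [{ x, y, score, name }, ...], score, ... }
  const nose = poses[0].keypoints.find((k) => k.name === "nose");
  if (nose && nose.score > 0.2) {
    console.log("nose at", nose.x, nose.y);
  }
});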

package.json

Lines changed: 2 additions & 0 deletions
@@ -27,7 +27,9 @@
   },
   "dependencies": {
     "@mediapipe/hands": "^0.4.1675469240",
+    "@mediapipe/pose": "^0.5.1675469404",
     "@tensorflow-models/hand-pose-detection": "^2.0.0",
+    "@tensorflow-models/pose-detection": "^2.1.0",
     "@tensorflow/tfjs": "^4.2.0",
     "@tensorflow/tfjs-vis": "^1.5.1",
     "axios": "^1.3.4"

src/PoseDetection/index.js

Lines changed: 79 additions & 150 deletions
@@ -10,183 +10,117 @@ Ported from pose-detection at Tensorflow.js

 import EventEmitter from "events";
 import * as tf from "@tensorflow/tfjs";
-import * as posenet from "@tensorflow-models/posenet";
+import * as bodyPoseDetection from "@tensorflow-models/pose-detection";
 import callCallback from "../utils/callcallback";
 import handleArguments from "../utils/handleArguments";
+import { mediaReady } from "../utils/imageUtilities";

-const DEFAULTS = {
-  architecture: "MobileNetV1", // 'MobileNetV1', 'ResNet50'
-  outputStride: 16, // 8, 16, 32
-  flipHorizontal: false, // true, false
-  minConfidence: 0.5,
-  maxPoseDetections: 5, // any number > 1
-  scoreThreshold: 0.5,
-  nmsRadius: 20, // any number > 0
-  detectionType: "multiple", // 'single'
-  inputResolution: 256, // or { width: 257, height: 200 }
-  multiplier: 0.75, // 1.01, 1.0, 0.75, or 0.50 -- only for MobileNet
-  quantBytes: 2, // 4, 2, 1
-  modelUrl: null, // url path to model
-};
-
-class PoseNet extends EventEmitter {
+class PoseDetection extends EventEmitter {
   /**
    * @typedef {Object} options
-   * @property {string} architecture - default 'MobileNetV1',
-   * @property {number} inputResolution - default 257,
-   * @property {number} outputStride - default 16
-   * @property {boolean} flipHorizontal - default false
-   * @property {number} minConfidence - default 0.5
-   * @property {number} maxPoseDetections - default 5
-   * @property {number} scoreThreshold - default 0.5
-   * @property {number} nmsRadius - default 20
-   * @property {String} detectionType - default single
-   * @property {number} nmsRadius - default 0.75,
-   * @property {number} quantBytes - default 2,
-   * @property {string} modelUrl - default null
+   * @property {string} modelType - Optional. specify what model variant to load from. Default: 'SINGLEPOSE_LIGHTNING'.
+   * @property {boolean} enableSmoothing - Optional. Whether to use temporal filter to smooth keypoints across frames. Default: true.
+   * @property {string} modelUrl - Optional. A string that specifies custom url of the model. Default to load from tf.hub.
+   * @property {number} minPoseScore - Optional. The minimum confidence score for a pose to be detected. Default: 0.25.
+   * @property {number} multiPoseMaxDimension - Optional. The target maximum dimension to use as the input to the multi-pose model. Must be a mutiple of 32. Default: 256.
+   * @property {boolean} enableTracking - Optional. Track each person across the frame with a unique ID. Default: true.
+   * @property {string} trackerType - Optional. Specify what type of tracker to use. Default: 'boundingBox'.
+   * @property {Object} trackerConfig - Optional. Specify tracker configurations. Use tf.js setting by default.
    */
+
   /**
    * Create a PoseNet model.
    * @param {HTMLVideoElement || p5.Video} video - Optional. A HTML video element or a p5 video element.
   * @param {options} options - Optional. An object describing a model accuracy and performance.
-   * @param {String} detectionType - Optional. A String value to run 'single' or 'multiple' estimation.
   * @param {function} callback Optional. A function to run once the model has been loaded.
   * If no callback is provided, it will return a promise that will be resolved once the
   * model has loaded.
   */
-  constructor(video, options, detectionType, callback) {
+  constructor(video, options, callback) {
     super();
+
     this.video = video;
-    /**
-     * The type of detection. 'single' or 'multiple'
-     * @type {String}
-     * @public
-     */
-    this.modelUrl = options.modelUrl || null;
-    this.architecture = options.architecture || DEFAULTS.architecture;
-    this.detectionType =
-      detectionType || options.detectionType || DEFAULTS.detectionType;
-    this.outputStride = options.outputStride || DEFAULTS.outputStride;
-    this.flipHorizontal = options.flipHorizontal || DEFAULTS.flipHorizontal;
-    this.scoreThreshold = options.scoreThreshold || DEFAULTS.scoreThreshold;
-    this.minConfidence = options.minConfidence || DEFAULTS.minConfidence;
-    this.maxPoseDetections =
-      options.maxPoseDetections || DEFAULTS.maxPoseDetections;
-    this.multiplier = options.multiplier || DEFAULTS.multiplier;
-    this.inputResolution = options.inputResolution || DEFAULTS.inputResolution;
-    this.quantBytes = options.quantBytes || DEFAULTS.quantBytes;
-    this.nmsRadius = options.nmsRadius || DEFAULTS.nmsRadius;
-    this.ready = callCallback(this.load(), callback);
-    // this.then = this.ready.then;
+    this.model = null;
+    this.modelReady = false;
+    this.config = options;
+
+    this.ready = callCallback(this.loadModel(), callback);
   }

-  async load() {
-    let modelJson;
-    if (this.architecture.toLowerCase() === "mobilenetv1") {
-      modelJson = {
-        architecture: this.architecture,
-        outputStride: this.outputStride,
-        inputResolution: this.inputResolution,
-        multiplier: this.multiplier,
-        quantBytes: this.quantBytes,
-        modelUrl: this.modelUrl,
-      };
-    } else {
-      modelJson = {
-        architecture: this.architecture,
-        outputStride: this.outputStride,
-        inputResolution: this.inputResolution,
-        quantBytes: this.quantBytes,
-      };
+  /**
+   * Load the model and set it to this.model
+   * @return {this} the detector model.
+   */
+  async loadModel() {
+    const pipeline = bodyPoseDetection.SupportedModels.MoveNet;
+    //Set the config to user defined or default values
+    const modelConfig = {
+      enableSmoothing: this.config.enableSmoothing ?? true,
+      modelUrl: this.config.modelUrl,
+      minPoseScore: this.config.minPoseScore ?? 0.25,
+      multiPoseMaxDimension: this.config.multiPoseMaxDimension ?? 256,
+      enableTracking: this.config.enableTracking ?? true,
+      trackerType: this.config.trackerType ?? "boundingBox",
+      trackerConfig: this.config.trackerConfig,
+    };
+    switch (this.config.modelType) {
+      case "SINGLEPOSE_LIGHTNING":
+        modelConfig.modelType =
+          bodyPoseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING;
+        break;
+      case "SINGLEPOSE_THUNDER":
+        modelConfig.modelType =
+          bodyPoseDetection.movenet.modelType.SINGLEPOSE_THUNDER;
+      case "MULTIPOSE_LIGHTNING":
+        modelConfig.modelType =
+          bodyPoseDetection.movenet.modelType.MULTIPOSE_LIGHTNING;
     }
-
-    this.net = await posenet.load(modelJson);
+    // Load the detector model
+    await tf.setBackend("webgl");
+    this.model = await bodyPoseDetection.createDetector(pipeline, modelConfig);
+    this.modelReady = true;

     if (this.video) {
-      if (this.video.readyState === 0) {
-        await new Promise((resolve) => {
-          this.video.onloadeddata = () => resolve();
-        });
-      }
-      if (this.detectionType === "single") {
-        this.singlePose();
-      } else {
-        this.multiPose();
-      }
+      this.predict();
     }
-    return this;
-  }

-  skeleton(keypoints, confidence = this.minConfidence) {
-    return posenet.getAdjacentKeyPoints(keypoints, confidence);
+    return this;
   }

-  // eslint-disable-next-line class-methods-use-this
-  mapParts(pose) {
-    const newPose = JSON.parse(JSON.stringify(pose));
-    newPose.keypoints.forEach((keypoint) => {
-      newPose[keypoint.part] = {
-        x: keypoint.position.x,
-        y: keypoint.position.y,
-        confidence: keypoint.score,
-      };
-    });
-    return newPose;
-  }
+  //Add named keypoints to a MoveNet pose object
+  // mapParts(pose) {
+  //   const newPose = JSON.parse(JSON.stringify(pose));
+  //   newPose.keypoints.forEach((keypoint) => {
+  //     newPose[keypoint.part] = {
+  //       x: keypoint.position.x,
+  //       y: keypoint.position.y,
+  //       confidence: keypoint.score,
+  //     };
+  //   });
+  //   return newPose;
+  // }

   /**
    * Given an image or video, returns an array of objects containing pose estimations
-   * using single or multi-pose detection.
-   * @param {HTMLVideoElement || p5.Video || function} inputOr
-   * @param {function} cb
+   * @param {HTMLVideoElement || p5.Video || function} inputOr - An HMTL or p5.js image, video, or canvas element to run the prediction on.
+   * @param {function} cb - A callback function to handle the predictions.
    */
-  async singlePose(inputOr, cb) {
-    const { image: input, callback } = handleArguments(this.video, inputOr, cb);
-
-    const pose = await this.net.estimateSinglePose(input, {
-      flipHorizontal: this.flipHorizontal,
-    });
-    const poseWithParts = this.mapParts(pose);
-    const result = [
-      { pose: poseWithParts, skeleton: this.skeleton(pose.keypoints) },
-    ];
-    this.emit("pose", result);
-
-    if (this.video) {
-      return tf.nextFrame().then(() => this.singlePose());
-    }
-
-    if (typeof callback === "function") {
-      callback(result);
+  async predict(inputOr, cb) {
+    const { image, callback } = handleArguments(this.video, inputOr, cb);
+    if (!image) {
+      throw new Error("No input image found.");
     }
+    // If video is provided, wait for video to be loaded
+    await mediaReady(image, false);
+    const result = await this.model.estimatePoses(image);

-    return result;
-  }
+    //Add named keypoints to each pose object
+    //const result = poses.map((pose) => this.mapParts(pose));

-  /**
-   * Given an image or video, returns an array of objects containing pose
-   * estimations using single or multi-pose detection.
-   * @param {HTMLVideoElement || p5.Video || function} inputOr
-   * @param {function} cb
-   */
-  async multiPose(inputOr, cb) {
-    const { image: input, callback } = handleArguments(this.video, inputOr, cb);
-
-    const poses = await this.net.estimateMultiplePoses(input, {
-      flipHorizontal: this.flipHorizontal,
-      maxDetections: this.maxPoseDetections,
-      scoreThreshold: this.scoreThreshold,
-      nmsRadius: this.nmsRadius,
-    });
-
-    const posesWithParts = poses.map((pose) => this.mapParts(pose));
-    const result = posesWithParts.map((pose) => ({
-      pose,
-      skeleton: this.skeleton(pose.keypoints),
-    }));
     this.emit("pose", result);
+
     if (this.video) {
-      return tf.nextFrame().then(() => this.multiPose());
+      return tf.nextFrame().then(() => this.predict());
     }

     if (typeof callback === "function") {
@@ -198,13 +132,8 @@ class PoseNet extends EventEmitter {
 }

 const poseDetection = (...inputs) => {
-  const {
-    video,
-    options = {},
-    callback,
-    string: detectionType,
-  } = handleArguments(...inputs);
-  return new PoseNet(video, options, detectionType, callback);
+  const { video, options = {}, callback } = handleArguments(...inputs);
+  return new PoseDetection(video, options, callback);
 };

 export default poseDetection;
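For reference, a minimal usage sketch of the new entry point, pieced together from the constructor signature, the options @typedef, and the "pose" event in the file above; the specific option values are illustrative and not part of this commit:

let poseDetector;

function setup() {
  createCanvas(640, 480);
  const video = createCapture(VIDEO);
  const options = {
    modelType: "MULTIPOSE_LIGHTNING", // or "SINGLEPOSE_LIGHTNING" / "SINGLEPOSE_THUNDER"
    enableSmoothing: true,
    minPoseScore: 0.25,
    enableTracking: true,
  };
  // With a video argument, predict() keeps running via tf.nextFrame()
  // and results are delivered through the "pose" event.
  poseDetector = ml5.poseDetection(video, options, () => {
    console.log("MoveNet detector ready");
  });
  poseDetector.on("pose", (poses) => {
    // poses is the array returned by the detector's estimatePoses() call
  });
}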
