Skip to content

Commit 2e70dd3

Browse files
authored
Merge pull request #27 from ml5js/model-pose-detection
Model pose detection
2 parents 0d409c8 + 6867e88 commit 2e70dd3

File tree

7 files changed

+289
-0
lines changed

7 files changed

+289
-0
lines changed

examples/PoseDetection/index.html

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<!--
Copyright (c) 2018 ml5

This software is released under the MIT License.
https://opensource.org/licenses/MIT
-->

<!-- Added DOCTYPE so browsers render in standards mode instead of quirks mode,
     and a lang attribute for accessibility. -->
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <title>Pose detection example using p5.js</title>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.6.0/p5.js"></script>
    <script src="../../dist/ml5.js"></script>
  </head>

  <body>
    <script src="sketch.js"></script>
  </body>
</html>

examples/PoseDetection/sketch.js

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright (c) 2018 ml5
2+
//
3+
// This software is released under the MIT License.
4+
// https://opensource.org/licenses/MIT
5+
6+
/* ===
7+
ml5 Example
8+
poseDetection example using p5.js
9+
=== */
10+
11+
// Hidden p5 video-capture element; drawn manually in draw().
let video;
// The ml5 pose-detection model instance, created in setup().
let poseNet;
// Latest pose results emitted by the model; refreshed by gotPoses().
let poses = [];
14+
15+
// Set up the canvas, the webcam capture, and the pose-detection model.
function setup() {
  createCanvas(640, 480);

  // Create the video capture and hide it; draw() renders it manually.
  video = createCapture(VIDEO);
  video.size(width, height);
  video.hide();

  // Load the model and attach an event.
  // Bug fix: assign to the declared global `poseNet` instead of the
  // implicit global `poseDetector` (which throws in strict mode and
  // leaves the declared variable forever undefined).
  poseNet = ml5.poseDetection(video, modelReady);
  poseNet.on("pose", gotPoses);
}
27+
28+
// Event for pose detection
29+
// Event handler fired whenever the model produces new pose detections.
// Always stores the latest model output in the global `poses` array.
function gotPoses(detectionResults) {
  poses = detectionResults;
}
33+
34+
// Event for when model loaded
35+
// Event handler fired once the pose-detection model has finished loading.
function modelReady() {
  console.log("Model ready!");
}
38+
39+
// Render loop: draw the current video frame and overlay detected keypoints.
// Fix: removed the leftover `console.log(poses)` debug statement, which
// flooded the console on every animation frame (~60x per second).
function draw() {
  // Draw the video frame.
  image(video, 0, 0, width, height);

  // Draw all the tracked landmark points for each individual pose detected.
  for (let i = 0; i < poses.length; i++) {
    const pose = poses[i];
    // For each keypoint in the pose.
    for (let j = 0; j < pose.keypoints.length; j++) {
      const keypoint = pose.keypoints[j];
      // Only draw an ellipse if the confidence score of the keypoint is bigger than 0.2.
      if (keypoint.score > 0.2) {
        fill(255, 0, 0);
        noStroke();
        ellipse(keypoint.x, keypoint.y, 10, 10);
      }
    }
  }
}

package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
},
2828
"dependencies": {
2929
"@mediapipe/hands": "^0.4.1675469240",
30+
"@mediapipe/pose": "^0.5.1675469404",
3031
"@tensorflow-models/hand-pose-detection": "^2.0.0",
32+
"@tensorflow-models/pose-detection": "^2.1.0",
3133
"@tensorflow/tfjs": "^4.2.0",
3234
"@tensorflow/tfjs-vis": "^1.5.1",
3335
"axios": "^1.3.4"

src/PoseDetection/index.js

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Copyright (c) 2018 ml5
2+
//
3+
// This software is released under the MIT License.
4+
// https://opensource.org/licenses/MIT
5+
6+
/*
7+
PoseDetection
8+
Ported from pose-detection at Tensorflow.js
9+
*/
10+
11+
import EventEmitter from "events";
12+
import * as tf from "@tensorflow/tfjs";
13+
import * as bodyPoseDetection from "@tensorflow-models/pose-detection";
14+
import callCallback from "../utils/callcallback";
15+
import handleArguments from "../utils/handleArguments";
16+
import { mediaReady } from "../utils/imageUtilities";
17+
18+
class PoseDetection extends EventEmitter {
  /**
   * @typedef {Object} options
   * @property {string} modelType - Optional. Specify what model variant to load from. Default: 'MULTIPOSE_LIGHTNING'.
   * @property {boolean} enableSmoothing - Optional. Whether to use a temporal filter to smooth keypoints across frames. Default: true.
   * @property {string} modelUrl - Optional. A string that specifies a custom url of the model. Defaults to loading from tf.hub.
   * @property {number} minPoseScore - Optional. The minimum confidence score for a pose to be detected. Default: 0.25.
   * @property {number} multiPoseMaxDimension - Optional. The target maximum dimension to use as the input to the multi-pose model. Must be a multiple of 32. Default: 256.
   * @property {boolean} enableTracking - Optional. Track each person across the frame with a unique ID. Default: true.
   * @property {string} trackerType - Optional. Specify what type of tracker to use. Default: 'boundingBox'.
   * @property {Object} trackerConfig - Optional. Specify tracker configurations. Uses tf.js settings by default.
   */

  /**
   * Create a PoseDetection model.
   * @param {HTMLVideoElement | p5.Video} video - Optional. An HTML video element or a p5 video element.
   * @param {options} options - Optional. An object describing model accuracy and performance.
   * @param {function} callback - Optional. A function to run once the model has been loaded.
   *    If no callback is provided, it will return a promise that will be resolved once the
   *    model has loaded.
   */
  constructor(video, options = {}, callback) {
    super();

    this.video = video;
    this.model = null;
    this.modelReady = false;
    // Robustness fix: guard against a nullish options argument so that
    // loadModel() can read configuration keys without crashing when the
    // class is constructed directly (the factory passes {} but direct
    // construction may not).
    this.config = options ?? {};

    this.ready = callCallback(this.loadModel(), callback);
  }

  /**
   * Load the model and set it to this.model
   * @return {this} the detector model.
   */
  async loadModel() {
    const pipeline = bodyPoseDetection.SupportedModels.MoveNet;
    // Set the config to user defined or default values.
    const modelConfig = {
      enableSmoothing: this.config.enableSmoothing ?? true,
      modelUrl: this.config.modelUrl,
      minPoseScore: this.config.minPoseScore ?? 0.25,
      multiPoseMaxDimension: this.config.multiPoseMaxDimension ?? 256,
      enableTracking: this.config.enableTracking ?? true,
      trackerType: this.config.trackerType ?? "boundingBox",
      trackerConfig: this.config.trackerConfig,
    };
    // Use the multi-pose lightning model by default.
    switch (this.config.modelType) {
      case "SINGLEPOSE_LIGHTNING":
        modelConfig.modelType =
          bodyPoseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING;
        break;
      case "SINGLEPOSE_THUNDER":
        modelConfig.modelType =
          bodyPoseDetection.movenet.modelType.SINGLEPOSE_THUNDER;
        break;
      default:
        modelConfig.modelType =
          bodyPoseDetection.movenet.modelType.MULTIPOSE_LIGHTNING;
    }
    // Load the detector model.
    await tf.setBackend("webgl");
    this.model = await bodyPoseDetection.createDetector(pipeline, modelConfig);
    this.modelReady = true;

    // When a video was supplied, start the continuous detection loop.
    if (this.video) {
      this.predict();
    }

    return this;
  }

  // TODO: Add named keypoints to a MoveNet pose object

  /**
   * Given an image or video, returns an array of objects containing pose estimations
   * @param {HTMLVideoElement | p5.Video | function} inputOr - An HTML or p5.js image, video, or canvas element to run the prediction on.
   * @param {function} cb - A callback function to handle the predictions.
   */
  async predict(inputOr, cb) {
    const { image, callback } = handleArguments(this.video, inputOr, cb);
    if (!image) {
      throw new Error("No input image found.");
    }
    // If video is provided, wait for video to be loaded.
    await mediaReady(image, false);
    const result = await this.model.estimatePoses(image);

    // TODO: Add named keypoints to each pose object

    this.emit("pose", result);

    // In video mode, keep predicting on every animation frame; the emitted
    // "pose" event is the delivery mechanism in this case.
    if (this.video) {
      return tf.nextFrame().then(() => this.predict());
    }

    if (typeof callback === "function") {
      callback(result);
    }

    return result;
  }
}
123+
124+
/**
 * Factory for the PoseDetection model. Accepts the video, options and
 * callback in any order and forwards them to the class constructor.
 * @param {...*} inputs - Optional video element, options object, and callback.
 * @returns {PoseDetection} a new PoseDetection instance.
 */
const poseDetection = (...inputs) => {
  const args = handleArguments(...inputs);
  const { video, options = {}, callback } = args;
  const instance = new PoseDetection(video, options, callback);
  return instance;
};

export default poseDetection;

src/PoseDetection/index.test.js

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright (c) 2018 ml5
2+
//
3+
// This software is released under the MIT License.
4+
// https://opensource.org/licenses/MIT
5+
6+
import { asyncLoadImage } from "../utils/testingUtils";
7+
import poseNet from "./index";
8+
9+
// URL of the sample image used by the detection tests below.
const POSENET_IMG =
  "https://github.com/ml5js/ml5-adjacent/raw/master/02_ImageClassification_Video/starter.png";
11+
12+
// Default configuration values of the legacy PoseNet model, kept here as
// reference fixtures for the assertions below.
const POSENET_DEFAULTS = {
  architecture: "MobileNetV1",
  outputStride: 16,
  flipHorizontal: false,
  minConfidence: 0.5,
  maxPoseDetections: 5,
  scoreThreshold: 0.5,
  nmsRadius: 20,
  detectionType: "multiple",
  inputResolution: 256,
  multiplier: 0.75,
  quantBytes: 2,
};
25+
26+
// NOTE(review): this suite previously asserted legacy PoseNet properties
// (`architecture`, `singlePose`, `keypoint.part`, `keypoint.position`) that
// the new MoveNet-based PoseDetection class never exposes, so every
// assertion failed. Rewritten against the actual API: `modelReady`,
// `predict()`, and MoveNet-style keypoints (`name`, pixel `x`/`y`).
describe("PoseDetection", () => {
  let detector;

  beforeAll(async () => {
    jest.setTimeout(10000);
    detector = await poseNet();
  });

  it("instantiates the pose detector", () => {
    // The model flag flips to true once loadModel() resolves.
    expect(detector.modelReady).toBe(true);
    expect(detector.model).toBeDefined();
  });

  it("detects poses in image", async () => {
    const image = await asyncLoadImage(POSENET_IMG);

    // Result should be an array of pose objects, each with keypoints.
    const result = await detector.predict(image);
    expect(Array.isArray(result)).toBe(true);
    expect(result.length).toBeGreaterThanOrEqual(1);
    expect(result[0]).toHaveProperty("keypoints");

    // MoveNet keypoints carry a `name` and pixel-space `x`/`y` coordinates.
    const nose = result[0].keypoints.find((keypoint) => keypoint.name === "nose");
    expect(nose).toBeTruthy();
    expect(nose.x).toBeGreaterThan(0);
    expect(nose.y).toBeGreaterThan(0);
    expect(nose.score).toBeGreaterThan(0.2);
  });
});

src/index.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import neuralNetwork from "./NeuralNetwork";
22
import handpose from "./Handpose";
3+
import poseDetection from "./PoseDetection";
34
import * as tf from "@tensorflow/tfjs";
45
import * as tfvis from "@tensorflow/tfjs-vis";
56
import p5Utils from "./utils/p5Utils";
@@ -11,5 +12,6 @@ export default Object.assign(
1112
tfvis,
1213
neuralNetwork,
1314
handpose,
15+
poseDetection,
1416
}
1517
);

yarn.lock

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,26 @@
6262
resolved "https://registry.yarnpkg.com/@mediapipe/hands/-/hands-0.4.1675469240.tgz#f032b2f5deff5a69430693f94be45dd9854e803f"
6363
integrity sha512-GxoZvL1mmhJxFxjuyj7vnC++JIuInGznHBin5c7ZSq/RbcnGyfEcJrkM/bMu5K1Mz/2Ko+vEX6/+wewmEHPrHg==
6464

65+
"@mediapipe/pose@^0.5.1675469404":
66+
version "0.5.1675469404"
67+
resolved "https://registry.yarnpkg.com/@mediapipe/pose/-/pose-0.5.1675469404.tgz#8f81e64c6561b2357a021a134b54de0204bafc72"
68+
integrity sha512-DFZsNWTsSphRIZppnUCuunzBiHP2FdJXR9ehc7mMi4KG+oPaOH0Em3d6kr7Py+TSyTXC1doH88KcF28k2sBxsQ==
69+
6570
"@tensorflow-models/hand-pose-detection@^2.0.0":
6671
version "2.0.0"
6772
resolved "https://registry.yarnpkg.com/@tensorflow-models/hand-pose-detection/-/hand-pose-detection-2.0.0.tgz#967b26d17d26454d0625c2af2264fd7aad8fdc35"
6873
integrity sha512-wAiu/SpigjKuhlEdIvPp84FyzIH0v8kHn/jB/VslUn/pV75Kpsv8Jk0S55oC/Jj54B/fLDZU19+zYN7lQMBCxg==
6974
dependencies:
7075
rimraf "^3.0.2"
7176

77+
"@tensorflow-models/pose-detection@^2.1.0":
78+
version "2.1.0"
79+
resolved "https://registry.yarnpkg.com/@tensorflow-models/pose-detection/-/pose-detection-2.1.0.tgz#733ce55dfe4a75d40cb04935e461beacac2e8e4d"
80+
integrity sha512-4WOgxiPuA1ymZff9Epez2GsC0FSv7Fj8olu5LzXB/JAYo/zgyu2PFOQFeDXvvqyXwxdB4IrYYsgerUqlZaGdSQ==
81+
dependencies:
82+
rimraf "^3.0.2"
83+
tslib "2.4.0"
84+
7285
"@tensorflow/[email protected]":
7386
version "4.8.0"
7487
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-4.8.0.tgz#6281ab0a93400f2b5c7b2efa07b0befb895a0260"
@@ -2880,6 +2893,11 @@ tr46@~0.0.3:
28802893
resolved "https://registry.yarnpkg.com/tr46/-/tr46-0.0.3.tgz#8184fd347dac9cdc185992f3a6622e14b9d9ab6a"
28812894
integrity sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==
28822895

2896+
[email protected]:
2897+
version "2.4.0"
2898+
resolved "https://registry.yarnpkg.com/tslib/-/tslib-2.4.0.tgz#7cecaa7f073ce680a05847aa77be941098f36dc3"
2899+
integrity sha512-d6xOpEDfsi2CZVlPQzGeux8XMwLT9hssAsaPYExaQMuYskwb+x1x7J371tWlbBdWHroy99KnVB6qIkUbs5X3UQ==
2900+
28832901
tslib@^2.0.3:
28842902
version "2.6.0"
28852903
resolved "https://registry.yarnpkg.com/tslib/-/tslib-2.6.0.tgz#b295854684dbda164e181d259a22cd779dcd7bc3"

0 commit comments

Comments
 (0)