
Commit 013b498

Bring in old posenet code
1 parent b751612 commit 013b498

File tree: 5 files changed (+378 / -0 lines)


examples/PoseDetection/index.html

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
<!--
Copyright (c) 2018 ml5

This software is released under the MIT License.
https://opensource.org/licenses/MIT
-->

<html>
  <head>
    <meta charset="UTF-8" />
    <title>PoseNet example using p5.js</title>

    <script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.6.0/p5.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.6.0/addons/p5.dom.min.js"></script>
    <script
      src="https://unpkg.com/[email protected]/dist/ml5.min.js"
      type="text/javascript"
    ></script>
    <link rel="stylesheet" type="text/css" href="style.css" />
  </head>

  <body>
    <p id="status">Loading model...</p>
    <script src="sketch.js"></script>
  </body>
</html>

examples/PoseDetection/sketch.js

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

/* ===
ml5 Example
PoseNet example using p5.js
=== */

let video;
let poseNet;
let poses = [];

function setup() {
  createCanvas(640, 480);
  video = createCapture(VIDEO);
  video.size(width, height);

  // Create a new poseNet instance watching the video (multi-pose detection by default)
  poseNet = ml5.poseNet(video, modelReady);
  // This sets up an event that fills the global variable "poses"
  // with an array every time new poses are detected
  poseNet.on("pose", function (results) {
    poses = results;
  });
  // Hide the video element, and just show the canvas
  video.hide();
}

function modelReady() {
  select("#status").html("Model Loaded");
}

function draw() {
  image(video, 0, 0, width, height);

  // We can call both functions to draw all keypoints and the skeletons
  drawKeypoints();
  drawSkeleton();
}

// A function to draw ellipses over the detected keypoints
function drawKeypoints() {
  // Loop through all the poses detected
  for (let i = 0; i < poses.length; i++) {
    // For each pose detected, loop through all the keypoints
    let pose = poses[i].pose;
    for (let j = 0; j < pose.keypoints.length; j++) {
      // A keypoint is an object describing a body part (like rightArm or leftShoulder)
      let keypoint = pose.keypoints[j];
      // Only draw an ellipse if the keypoint's confidence score is greater than 0.2
      if (keypoint.score > 0.2) {
        fill(255, 0, 0);
        noStroke();
        ellipse(keypoint.position.x, keypoint.position.y, 10, 10);
      }
    }
  }
}

// A function to draw the skeletons
function drawSkeleton() {
  // Loop through all the skeletons detected
  for (let i = 0; i < poses.length; i++) {
    let skeleton = poses[i].skeleton;
    // For every skeleton, loop through all body connections
    for (let j = 0; j < skeleton.length; j++) {
      let partA = skeleton[j][0];
      let partB = skeleton[j][1];
      stroke(255, 0, 0);
      line(
        partA.position.x,
        partA.position.y,
        partB.position.x,
        partB.position.y
      );
    }
  }
}
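
Note: each entry in the poses array delivered by the "pose" event pairs a pose with its skeleton (see src/PoseDetection/index.js below). A minimal sketch of the shape one entry takes; the coordinate and score values here are illustrative placeholders, not real output:

// Illustrative shape of one entry in `poses` (placeholder values).
// mapParts() in src/PoseDetection/index.js also copies every keypoint onto
// the pose under its part name, so pose.nose mirrors the "nose" keypoint.
const entry = {
  pose: {
    score: 0.92,
    keypoints: [
      { part: "nose", score: 0.99, position: { x: 320, y: 180 } },
      // ...remaining keypoints
    ],
    nose: { x: 320, y: 180, confidence: 0.99 }, // added by mapParts()
  },
  skeleton: [
    // pairs of adjacent keypoints from posenet.getAdjacentKeyPoints()
  ],
};
// So drawKeypoints() above could equivalently read poses[0].pose.nose.x.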

src/PoseDetection/index.js

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

/*
PoseDetection
Ported from pose-detection at Tensorflow.js
*/

import EventEmitter from "events";
import * as tf from "@tensorflow/tfjs";
import * as posenet from "@tensorflow-models/posenet";
import callCallback from "../utils/callcallback";
import handleArguments from "../utils/handleArguments";

const DEFAULTS = {
  architecture: "MobileNetV1", // 'MobileNetV1', 'ResNet50'
  outputStride: 16, // 8, 16, 32
  flipHorizontal: false, // true, false
  minConfidence: 0.5,
  maxPoseDetections: 5, // any number > 1
  scoreThreshold: 0.5,
  nmsRadius: 20, // any number > 0
  detectionType: "multiple", // 'single'
  inputResolution: 256, // or { width: 257, height: 200 }
  multiplier: 0.75, // 1.01, 1.0, 0.75, or 0.50 -- only for MobileNet
  quantBytes: 2, // 4, 2, 1
  modelUrl: null, // url path to model
};

class PoseNet extends EventEmitter {
  /**
   * @typedef {Object} options
   * @property {string} architecture - default 'MobileNetV1'
   * @property {number} inputResolution - default 256
   * @property {number} outputStride - default 16
   * @property {boolean} flipHorizontal - default false
   * @property {number} minConfidence - default 0.5
   * @property {number} maxPoseDetections - default 5
   * @property {number} scoreThreshold - default 0.5
   * @property {number} nmsRadius - default 20
   * @property {String} detectionType - default 'multiple'
   * @property {number} multiplier - default 0.75
   * @property {number} quantBytes - default 2
   * @property {string} modelUrl - default null
   */
  /**
   * Create a PoseNet model.
   * @param {HTMLVideoElement || p5.Video} video - Optional. A HTML video element or a p5 video element.
   * @param {options} options - Optional. An object describing a model accuracy and performance.
   * @param {String} detectionType - Optional. A String value to run 'single' or 'multiple' estimation.
   * @param {function} callback - Optional. A function to run once the model has been loaded.
   *    If no callback is provided, it will return a promise that will be resolved once the
   *    model has loaded.
   */
  constructor(video, options, detectionType, callback) {
    super();
    this.video = video;
    /**
     * The type of detection. 'single' or 'multiple'
     * @type {String}
     * @public
     */
    this.modelUrl = options.modelUrl || null;
    this.architecture = options.architecture || DEFAULTS.architecture;
    this.detectionType =
      detectionType || options.detectionType || DEFAULTS.detectionType;
    this.outputStride = options.outputStride || DEFAULTS.outputStride;
    this.flipHorizontal = options.flipHorizontal || DEFAULTS.flipHorizontal;
    this.scoreThreshold = options.scoreThreshold || DEFAULTS.scoreThreshold;
    this.minConfidence = options.minConfidence || DEFAULTS.minConfidence;
    this.maxPoseDetections =
      options.maxPoseDetections || DEFAULTS.maxPoseDetections;
    this.multiplier = options.multiplier || DEFAULTS.multiplier;
    this.inputResolution = options.inputResolution || DEFAULTS.inputResolution;
    this.quantBytes = options.quantBytes || DEFAULTS.quantBytes;
    this.nmsRadius = options.nmsRadius || DEFAULTS.nmsRadius;
    this.ready = callCallback(this.load(), callback);
    // this.then = this.ready.then;
  }

  async load() {
    let modelJson;
    if (this.architecture.toLowerCase() === "mobilenetv1") {
      modelJson = {
        architecture: this.architecture,
        outputStride: this.outputStride,
        inputResolution: this.inputResolution,
        multiplier: this.multiplier,
        quantBytes: this.quantBytes,
        modelUrl: this.modelUrl,
      };
    } else {
      modelJson = {
        architecture: this.architecture,
        outputStride: this.outputStride,
        inputResolution: this.inputResolution,
        quantBytes: this.quantBytes,
      };
    }

    this.net = await posenet.load(modelJson);

    if (this.video) {
      if (this.video.readyState === 0) {
        await new Promise((resolve) => {
          this.video.onloadeddata = () => resolve();
        });
      }
      if (this.detectionType === "single") {
        this.singlePose();
      } else {
        this.multiPose();
      }
    }
    return this;
  }

  skeleton(keypoints, confidence = this.minConfidence) {
    return posenet.getAdjacentKeyPoints(keypoints, confidence);
  }

  // eslint-disable-next-line class-methods-use-this
  mapParts(pose) {
    const newPose = JSON.parse(JSON.stringify(pose));
    newPose.keypoints.forEach((keypoint) => {
      newPose[keypoint.part] = {
        x: keypoint.position.x,
        y: keypoint.position.y,
        confidence: keypoint.score,
      };
    });
    return newPose;
  }

  /**
   * Given an image or video, returns an array of objects containing pose estimations
   * using single-pose detection.
   * @param {HTMLVideoElement || p5.Video || function} inputOr
   * @param {function} cb
   */
  async singlePose(inputOr, cb) {
    const { image: input, callback } = handleArguments(this.video, inputOr, cb);

    const pose = await this.net.estimateSinglePose(input, {
      flipHorizontal: this.flipHorizontal,
    });
    const poseWithParts = this.mapParts(pose);
    const result = [
      { pose: poseWithParts, skeleton: this.skeleton(pose.keypoints) },
    ];
    this.emit("pose", result);

    if (this.video) {
      return tf.nextFrame().then(() => this.singlePose());
    }

    if (typeof callback === "function") {
      callback(result);
    }

    return result;
  }

  /**
   * Given an image or video, returns an array of objects containing pose
   * estimations using multi-pose detection.
   * @param {HTMLVideoElement || p5.Video || function} inputOr
   * @param {function} cb
   */
  async multiPose(inputOr, cb) {
    const { image: input, callback } = handleArguments(this.video, inputOr, cb);

    const poses = await this.net.estimateMultiplePoses(input, {
      flipHorizontal: this.flipHorizontal,
      maxDetections: this.maxPoseDetections,
      scoreThreshold: this.scoreThreshold,
      nmsRadius: this.nmsRadius,
    });

    const posesWithParts = poses.map((pose) => this.mapParts(pose));
    const result = posesWithParts.map((pose) => ({
      pose,
      skeleton: this.skeleton(pose.keypoints),
    }));
    this.emit("pose", result);
    if (this.video) {
      return tf.nextFrame().then(() => this.multiPose());
    }

    if (typeof callback === "function") {
      callback(result);
    }

    return result;
  }
}

const poseDetection = (...inputs) => {
  const {
    video,
    options = {},
    callback,
    string: detectionType,
  } = handleArguments(...inputs);
  return new PoseNet(video, options, detectionType, callback);
};

export default poseDetection;
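
Outside the p5 example, the same factory can run against a static image. A minimal usage sketch in promise style; the import path and the img element are assumptions for illustration, not part of this commit:

// Assumed usage sketch: single-pose detection on an already-loaded image
// element, waiting on the `ready` promise set by
// callCallback(this.load(), callback) above.
import poseDetection from "./src/PoseDetection";

async function detect(img) {
  // No video argument, so load() resolves without starting a detection loop.
  const net = poseDetection({ detectionType: "single" });
  await net.ready;
  // singlePose resolves to [{ pose, skeleton }] and also emits "pose".
  const [{ pose, skeleton }] = await net.singlePose(img);
  console.log(pose.score, skeleton.length);
}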

src/PoseDetection/index.test.js

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

import { asyncLoadImage } from "../utils/testingUtils";
import poseNet from "./index";

const POSENET_IMG =
  "https://github.com/ml5js/ml5-adjacent/raw/master/02_ImageClassification_Video/starter.png";

const POSENET_DEFAULTS = {
  architecture: "MobileNetV1",
  outputStride: 16,
  flipHorizontal: false,
  minConfidence: 0.5,
  maxPoseDetections: 5,
  scoreThreshold: 0.5,
  nmsRadius: 20,
  detectionType: "multiple",
  inputResolution: 256,
  multiplier: 0.75,
  quantBytes: 2,
};

describe("PoseNet", () => {
  let net;

  beforeAll(async () => {
    jest.setTimeout(10000);
    net = await poseNet();
  });

  it("instantiates poseNet", () => {
    expect(net.architecture).toBe(POSENET_DEFAULTS.architecture);
    expect(net.outputStride).toBe(POSENET_DEFAULTS.outputStride);
    expect(net.inputResolution).toBe(POSENET_DEFAULTS.inputResolution);
    expect(net.multiplier).toBe(POSENET_DEFAULTS.multiplier);
    expect(net.quantBytes).toBe(POSENET_DEFAULTS.quantBytes);
  });

  it("detects poses in image", async () => {
    const image = await asyncLoadImage(POSENET_IMG);

    // Result should be an array with a single object containing pose and skeleton.
    const result = await net.singlePose(image);
    expect(result).toHaveLength(1);
    expect(result[0]).toHaveProperty("pose");
    expect(result[0]).toHaveProperty("skeleton");

    // Verify a known outcome.
    const nose = result[0].pose.keypoints.find(
      (keypoint) => keypoint.part === "nose"
    );
    expect(nose).toBeTruthy();
    expect(nose.position.x).toBeCloseTo(448.6, 0);
    expect(nose.position.y).toBeCloseTo(255.9, 0);
    expect(nose.score).toBeCloseTo(0.999);
  });
});

src/index.js

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 import neuralNetwork from "./NeuralNetwork";
 import handpose from "./Handpose";
+import poseDetection from "./PoseDetection";
 import * as tf from "@tensorflow/tfjs";
 import * as tfvis from "@tensorflow/tfjs-vis";
 import p5Utils from "./utils/p5Utils";
@@ -11,5 +12,6 @@ export default Object.assign(
   tfvis,
   neuralNetwork,
   handpose,
+  poseDetection,
 }
 );
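
With this change the library-level entry point is ml5.poseDetection, while the bundled example above still calls ml5.poseNet from the older CDN build. A minimal sketch of the new name in a browser sketch, assuming a build of this branch is loaded; video, modelReady, and poses are the same names used in examples/PoseDetection/sketch.js:

// Sketch of the newly exported name (an assumption against this branch's
// build; the 0.x CDN release used in examples/PoseDetection exposes
// ml5.poseNet instead).
const net = ml5.poseDetection(video, { flipHorizontal: false }, modelReady);
net.on("pose", (results) => {
  poses = results;
});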
