@@ -10,183 +10,117 @@ Ported from pose-detection at Tensorflow.js
 import EventEmitter from "events";
 import * as tf from "@tensorflow/tfjs";
-import * as posenet from "@tensorflow-models/posenet";
+import * as bodyPoseDetection from "@tensorflow-models/pose-detection";
 import callCallback from "../utils/callcallback";
 import handleArguments from "../utils/handleArguments";
+import { mediaReady } from "../utils/imageUtilities";
 
-const DEFAULTS = {
-  architecture: "MobileNetV1", // 'MobileNetV1', 'ResNet50'
-  outputStride: 16, // 8, 16, 32
-  flipHorizontal: false, // true, false
-  minConfidence: 0.5,
-  maxPoseDetections: 5, // any number > 1
-  scoreThreshold: 0.5,
-  nmsRadius: 20, // any number > 0
-  detectionType: "multiple", // 'single'
-  inputResolution: 256, // or { width: 257, height: 200 }
-  multiplier: 0.75, // 1.01, 1.0, 0.75, or 0.50 -- only for MobileNet
-  quantBytes: 2, // 4, 2, 1
-  modelUrl: null, // url path to model
-};
-
-class PoseNet extends EventEmitter {
+class PoseDetection extends EventEmitter {
   /**
    * @typedef {Object} options
-   * @property {string} architecture - default 'MobileNetV1',
-   * @property {number} inputResolution - default 257,
-   * @property {number} outputStride - default 16
-   * @property {boolean} flipHorizontal - default false
-   * @property {number} minConfidence - default 0.5
-   * @property {number} maxPoseDetections - default 5
-   * @property {number} scoreThreshold - default 0.5
-   * @property {number} nmsRadius - default 20
-   * @property {String} detectionType - default single
-   * @property {number} nmsRadius - default 0.75,
-   * @property {number} quantBytes - default 2,
-   * @property {string} modelUrl - default null
+   * @property {string} modelType - Optional. Specify which model variant to load. Default: 'SINGLEPOSE_LIGHTNING'.
+   * @property {boolean} enableSmoothing - Optional. Whether to use a temporal filter to smooth keypoints across frames. Default: true.
+   * @property {string} modelUrl - Optional. A custom URL to load the model from. Defaults to loading from TF Hub.
+   * @property {number} minPoseScore - Optional. The minimum confidence score for a pose to be detected. Default: 0.25.
+   * @property {number} multiPoseMaxDimension - Optional. The target maximum dimension to use as the input to the multi-pose model. Must be a multiple of 32. Default: 256.
+   * @property {boolean} enableTracking - Optional. Track each person across frames with a unique ID. Default: true.
+   * @property {string} trackerType - Optional. Specify which type of tracker to use. Default: 'boundingBox'.
+   * @property {Object} trackerConfig - Optional. Specify tracker configurations. Uses the tf.js defaults.
    */
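+  // Illustrative only: an options object matching the typedef above might look
+  // like { modelType: "MULTIPOSE_LIGHTNING", enableSmoothing: true, minPoseScore: 0.3 }.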
+
   /**
    * Create a PoseNet model.
    * @param {HTMLVideoElement || p5.Video} video - Optional. An HTML video element or a p5 video element.
    * @param {options} options - Optional. An object describing model accuracy and performance.
-   * @param {String} detectionType - Optional. A String value to run 'single' or 'multiple' estimation.
    * @param {function} callback Optional. A function to run once the model has been loaded.
    * If no callback is provided, it will return a promise that will be resolved once the
    * model has loaded.
    */
-  constructor(video, options, detectionType, callback) {
+  constructor(video, options, callback) {
     super();
+
     this.video = video;
-    /**
-     * The type of detection. 'single' or 'multiple'
-     * @type {String}
-     * @public
-     */
-    this.modelUrl = options.modelUrl || null;
-    this.architecture = options.architecture || DEFAULTS.architecture;
-    this.detectionType =
-      detectionType || options.detectionType || DEFAULTS.detectionType;
-    this.outputStride = options.outputStride || DEFAULTS.outputStride;
-    this.flipHorizontal = options.flipHorizontal || DEFAULTS.flipHorizontal;
-    this.scoreThreshold = options.scoreThreshold || DEFAULTS.scoreThreshold;
-    this.minConfidence = options.minConfidence || DEFAULTS.minConfidence;
-    this.maxPoseDetections =
-      options.maxPoseDetections || DEFAULTS.maxPoseDetections;
-    this.multiplier = options.multiplier || DEFAULTS.multiplier;
-    this.inputResolution = options.inputResolution || DEFAULTS.inputResolution;
-    this.quantBytes = options.quantBytes || DEFAULTS.quantBytes;
-    this.nmsRadius = options.nmsRadius || DEFAULTS.nmsRadius;
-    this.ready = callCallback(this.load(), callback);
-    // this.then = this.ready.then;
+    this.model = null;
+    this.modelReady = false;
+    this.config = options;
+
+    this.ready = callCallback(this.loadModel(), callback);
   }
 
-  async load() {
-    let modelJson;
-    if (this.architecture.toLowerCase() === "mobilenetv1") {
-      modelJson = {
-        architecture: this.architecture,
-        outputStride: this.outputStride,
-        inputResolution: this.inputResolution,
-        multiplier: this.multiplier,
-        quantBytes: this.quantBytes,
-        modelUrl: this.modelUrl,
-      };
-    } else {
-      modelJson = {
-        architecture: this.architecture,
-        outputStride: this.outputStride,
-        inputResolution: this.inputResolution,
-        quantBytes: this.quantBytes,
-      };
+  /**
+   * Load the model and set it to this.model
+   * @return {this} the detector instance, once the model has loaded.
+   */
+  async loadModel() {
+    const pipeline = bodyPoseDetection.SupportedModels.MoveNet;
+    // Set the config to user-defined or default values
+    const modelConfig = {
+      enableSmoothing: this.config.enableSmoothing ?? true,
+      modelUrl: this.config.modelUrl,
+      minPoseScore: this.config.minPoseScore ?? 0.25,
+      multiPoseMaxDimension: this.config.multiPoseMaxDimension ?? 256,
+      enableTracking: this.config.enableTracking ?? true,
+      trackerType: this.config.trackerType ?? "boundingBox",
+      trackerConfig: this.config.trackerConfig,
+    };
+    switch (this.config.modelType) {
+      case "SINGLEPOSE_LIGHTNING":
+        modelConfig.modelType =
+          bodyPoseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING;
+        break;
+      case "SINGLEPOSE_THUNDER":
+        modelConfig.modelType =
+          bodyPoseDetection.movenet.modelType.SINGLEPOSE_THUNDER;
+        break; // prevent fall-through into MULTIPOSE_LIGHTNING
+      case "MULTIPOSE_LIGHTNING":
+        modelConfig.modelType =
+          bodyPoseDetection.movenet.modelType.MULTIPOSE_LIGHTNING;
     }
-
-    this.net = await posenet.load(modelJson);
+    // Load the detector model
+    await tf.setBackend("webgl");
+    this.model = await bodyPoseDetection.createDetector(pipeline, modelConfig);
+    this.modelReady = true;
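+    // With a video input, the block below starts a continuous estimation
+    // loop; predict() re-schedules itself and emits a "pose" event per frame.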
 
     if (this.video) {
-      if (this.video.readyState === 0) {
-        await new Promise((resolve) => {
-          this.video.onloadeddata = () => resolve();
-        });
-      }
-      if (this.detectionType === "single") {
-        this.singlePose();
-      } else {
-        this.multiPose();
-      }
+      this.predict();
     }
-    return this;
-  }
 
-  skeleton(keypoints, confidence = this.minConfidence) {
-    return posenet.getAdjacentKeyPoints(keypoints, confidence);
+    return this;
   }
 
-  // eslint-disable-next-line class-methods-use-this
-  mapParts(pose) {
-    const newPose = JSON.parse(JSON.stringify(pose));
-    newPose.keypoints.forEach((keypoint) => {
-      newPose[keypoint.part] = {
-        x: keypoint.position.x,
-        y: keypoint.position.y,
-        confidence: keypoint.score,
-      };
-    });
-    return newPose;
-  }
+  // Add named keypoints to a MoveNet pose object
+  // mapParts(pose) {
+  //   const newPose = JSON.parse(JSON.stringify(pose));
+  //   newPose.keypoints.forEach((keypoint) => {
+  //     newPose[keypoint.part] = {
+  //       x: keypoint.position.x,
+  //       y: keypoint.position.y,
+  //       confidence: keypoint.score,
+  //     };
+  //   });
+  //   return newPose;
+  // }
 
   /**
    * Given an image or video, returns an array of objects containing pose estimations
-   * using single or multi-pose detection.
-   * @param {HTMLVideoElement || p5.Video || function} inputOr
-   * @param {function} cb
+   * @param {HTMLVideoElement || p5.Video || function} inputOr - An HTML or p5.js image, video, or canvas element to run the prediction on.
+   * @param {function} cb - A callback function to handle the predictions.
    */
-  async singlePose(inputOr, cb) {
-    const { image: input, callback } = handleArguments(this.video, inputOr, cb);
-
-    const pose = await this.net.estimateSinglePose(input, {
-      flipHorizontal: this.flipHorizontal,
-    });
-    const poseWithParts = this.mapParts(pose);
-    const result = [
-      { pose: poseWithParts, skeleton: this.skeleton(pose.keypoints) },
-    ];
-    this.emit("pose", result);
-
-    if (this.video) {
-      return tf.nextFrame().then(() => this.singlePose());
-    }
-
-    if (typeof callback === "function") {
-      callback(result);
+  async predict(inputOr, cb) {
+    const { image, callback } = handleArguments(this.video, inputOr, cb);
+    if (!image) {
+      throw new Error("No input image found.");
     }
+    // If a video is provided, wait for it to be loaded
+    await mediaReady(image, false);
+    const result = await this.model.estimatePoses(image);
 
-    return result;
-  }
+    // Add named keypoints to each pose object
+    // const result = poses.map((pose) => this.mapParts(pose));
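+    // Each pose in result is expected to hold a keypoints array of
+    // { x, y, score, name } objects in the MoveNet output format.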
-
-  /**
-   * Given an image or video, returns an array of objects containing pose
-   * estimations using single or multi-pose detection.
-   * @param {HTMLVideoElement || p5.Video || function} inputOr
-   * @param {function} cb
-   */
-  async multiPose(inputOr, cb) {
-    const { image: input, callback } = handleArguments(this.video, inputOr, cb);
-
-    const poses = await this.net.estimateMultiplePoses(input, {
-      flipHorizontal: this.flipHorizontal,
-      maxDetections: this.maxPoseDetections,
-      scoreThreshold: this.scoreThreshold,
-      nmsRadius: this.nmsRadius,
-    });
-
-    const posesWithParts = poses.map((pose) => this.mapParts(pose));
-    const result = posesWithParts.map((pose) => ({
-      pose,
-      skeleton: this.skeleton(pose.keypoints),
-    }));
 
     this.emit("pose", result);
+
     if (this.video) {
-      return tf.nextFrame().then(() => this.multiPose());
+      return tf.nextFrame().then(() => this.predict());
     }
 
     if (typeof callback === "function") {
@@ -198,13 +132,8 @@ class PoseNet extends EventEmitter {
   }
 }
 
 const poseDetection = (...inputs) => {
-  const {
-    video,
-    options = {},
-    callback,
-    string: detectionType,
-  } = handleArguments(...inputs);
-  return new PoseNet(video, options, detectionType, callback);
+  const { video, options = {}, callback } = handleArguments(...inputs);
+  return new PoseDetection(video, options, callback);
 };
 
 export default poseDetection;
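
For reference, here is a minimal sketch of how the reworked entry point might be used. It assumes the module is imported directly and that a video element exists on the page; the option values and handlers are illustrative, not defaults:

```js
import poseDetection from "./PoseDetection";

// Illustrative input: any HTML video element (or p5 video) works.
const video = document.querySelector("video");

// Callback style: runs once the MoveNet detector has loaded.
const detector = poseDetection(video, { modelType: "MULTIPOSE_LIGHTNING" }, () => {
  console.log("model ready");
});

// Event style: with a video input, predict() loops via tf.nextFrame() and
// emits a "pose" event with the estimatePoses() results on every frame.
detector.on("pose", (poses) => {
  poses.forEach((pose) => console.log(pose.keypoints));
});
```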