|
| 1 | +// Copyright (c) 2018 ml5 |
| 2 | +// |
| 3 | +// This software is released under the MIT License. |
| 4 | +// https://opensource.org/licenses/MIT |
| 5 | + |
| 6 | +/* |
| 7 | +PoseDetection |
| 8 | +Ported from pose-detection at Tensorflow.js |
| 9 | +*/ |
| 10 | + |
| 11 | +import EventEmitter from "events"; |
| 12 | +import * as tf from "@tensorflow/tfjs"; |
| 13 | +import * as bodyPoseDetection from "@tensorflow-models/pose-detection"; |
| 14 | +import callCallback from "../utils/callcallback"; |
| 15 | +import handleArguments from "../utils/handleArguments"; |
| 16 | +import { mediaReady } from "../utils/imageUtilities"; |
| 17 | + |
class PoseDetection extends EventEmitter {
  /**
   * @typedef {Object} options
   * @property {string} modelType - Optional. Specify what model variant to load from. Default: 'MULTIPOSE_LIGHTNING'.
   * @property {boolean} enableSmoothing - Optional. Whether to use temporal filter to smooth keypoints across frames. Default: true.
   * @property {string} modelUrl - Optional. A string that specifies custom url of the model. Default to load from tf.hub.
   * @property {number} minPoseScore - Optional. The minimum confidence score for a pose to be detected. Default: 0.25.
   * @property {number} multiPoseMaxDimension - Optional. The target maximum dimension to use as the input to the multi-pose model. Must be a multiple of 32. Default: 256.
   * @property {boolean} enableTracking - Optional. Track each person across the frame with a unique ID. Default: true.
   * @property {string} trackerType - Optional. Specify what type of tracker to use. Default: 'boundingBox'.
   * @property {Object} trackerConfig - Optional. Specify tracker configurations. Use tf.js setting by default.
   */

  /**
   * Create a PoseDetection model.
   * @param {HTMLVideoElement | p5.Video} [video] - Optional. A HTML video element or a p5 video element.
   * @param {options} [options] - Optional. An object describing a model accuracy and performance.
   * @param {function} [callback] - Optional. A function to run once the model has been loaded.
   *    If no callback is provided, it will return a promise that will be resolved once the
   *    model has loaded.
   */
  constructor(video, options, callback) {
    super();

    this.video = video;
    this.model = null;
    this.modelReady = false;
    // Default to an empty object so loadModel() can safely read config
    // properties when the caller omits the (optional) options argument.
    this.config = options ?? {};

    this.ready = callCallback(this.loadModel(), callback);
  }

  /**
   * Load the MoveNet detector model and set it to this.model.
   * Starts the continuous prediction loop if a video was provided.
   * @return {this} the detector model.
   */
  async loadModel() {
    const pipeline = bodyPoseDetection.SupportedModels.MoveNet;
    // Set the config to user defined or default values
    const modelConfig = {
      enableSmoothing: this.config.enableSmoothing ?? true,
      modelUrl: this.config.modelUrl,
      minPoseScore: this.config.minPoseScore ?? 0.25,
      multiPoseMaxDimension: this.config.multiPoseMaxDimension ?? 256,
      enableTracking: this.config.enableTracking ?? true,
      trackerType: this.config.trackerType ?? "boundingBox",
      trackerConfig: this.config.trackerConfig,
    };
    // Use the multi-pose lightning model by default; unrecognized
    // modelType strings also fall through to the default.
    switch (this.config.modelType) {
      case "SINGLEPOSE_LIGHTNING":
        modelConfig.modelType =
          bodyPoseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING;
        break;
      case "SINGLEPOSE_THUNDER":
        modelConfig.modelType =
          bodyPoseDetection.movenet.modelType.SINGLEPOSE_THUNDER;
        break;
      default:
        modelConfig.modelType =
          bodyPoseDetection.movenet.modelType.MULTIPOSE_LIGHTNING;
    }
    // Load the detector model
    await tf.setBackend("webgl");
    this.model = await bodyPoseDetection.createDetector(pipeline, modelConfig);
    this.modelReady = true;

    if (this.video) {
      // NOTE: intentionally not awaited — kicks off the endless
      // frame-by-frame prediction loop for the attached video.
      this.predict();
    }

    return this;
  }

  //TODO: Add named keypoints to a MoveNet pose object

  /**
   * Given an image or video, returns an array of objects containing pose estimations.
   * When a video was attached at construction time, this loops on every frame
   * and emits a "pose" event instead of resolving with a single result.
   * @param {HTMLVideoElement | p5.Video | function} [inputOr] - An HTML or p5.js image, video, or canvas element to run the prediction on.
   * @param {function} [cb] - A callback function to handle the predictions.
   * @return {Promise<Array>} the pose estimation results (single-shot mode only).
   */
  async predict(inputOr, cb) {
    const { image, callback } = handleArguments(this.video, inputOr, cb);
    if (!image) {
      throw new Error("No input image found.");
    }
    // If video is provided, wait for video to be loaded
    await mediaReady(image, false);
    const result = await this.model.estimatePoses(image);

    // TODO: Add named keypoints to each pose object

    this.emit("pose", result);

    if (this.video) {
      // Continuous mode: schedule the next frame and recurse forever.
      return tf.nextFrame().then(() => this.predict());
    }

    if (typeof callback === "function") {
      callback(result);
    }

    return result;
  }
}
| 123 | + |
/**
 * Factory that accepts flexible arguments (video, options, callback in any
 * supported combination) and returns a new PoseDetection instance.
 * @param {...*} inputs - Any mix of a video element, an options object, and a callback.
 * @return {PoseDetection} the new pose detector.
 */
const poseDetection = (...inputs) => {
  const parsed = handleArguments(...inputs);
  // Fall back to an empty config only when options was left undefined,
  // mirroring a destructuring default.
  const config = parsed.options === undefined ? {} : parsed.options;
  return new PoseDetection(parsed.video, config, parsed.callback);
};

export default poseDetection;
0 commit comments