From 4d3fdb0674bd1ee46b418cff20aceb09fb92f5fa Mon Sep 17 00:00:00 2001 From: Linda Paiste Date: Thu, 22 Feb 2024 21:50:20 -0600 Subject: [PATCH] Extract common logic for detectStart and detectStop into a helper class. --- src/BodyPose/index.js | 108 +++-------------------- src/BodySegmentation/index.js | 111 ++--------------------- src/FaceMesh/index.js | 107 +++------------------- src/HandPose/index.js | 104 ++-------------------- src/ImageDetector/index.js | 162 ++++++++++++++++++++++++++++++++++ 5 files changed, 203 insertions(+), 389 deletions(-) create mode 100644 src/ImageDetector/index.js diff --git a/src/BodyPose/index.js b/src/BodyPose/index.js index f8130834..e3b4bd73 100644 --- a/src/BodyPose/index.js +++ b/src/BodyPose/index.js @@ -8,8 +8,9 @@ BodyPose Ported from pose-detection at Tensorflow.js */ -import * as tf from "@tensorflow/tfjs"; import * as poseDetection from "@tensorflow-models/pose-detection"; +import * as tf from "@tensorflow/tfjs"; +import ImageDetector from "../ImageDetector"; import callCallback from "../utils/callcallback"; import handleArguments from "../utils/handleArguments"; import { mediaReady } from "../utils/imageUtilities"; @@ -55,13 +56,6 @@ class BodyPose { this.model = null; this.config = options; this.runtimeConfig = {}; - this.detectMedia = null; - this.detectCallback = null; - - // flags for detectStart() and detectStop() - this.detecting = false; // true when detection loop is running - this.signalStop = false; // Signal to stop the loop - this.prevCall = ""; // Track previous call to detectStart() or detectStop() this.ready = callCallback(this.loadModel(), callback); } @@ -129,105 +123,23 @@ class BodyPose { this.model = await poseDetection.createDetector(pipeline, modelConfig); // for compatibility with p5's preload() - if (this.p5PreLoadExists) window._decrementPreload(); + if (this.p5PreLoadExists()) window._decrementPreload(); return this; } - /** - * A callback function that handles the pose detection results. - * @callback gotPoses - * @param {Array} results - An array of objects containing poses. - */ - /** * Asynchronously outputs a single pose prediction result when called. - * @param {*} media - An HMTL or p5.js image, video, or canvas element to run the prediction on. - * @param {gotPoses} callback - A callback function to handle the predictions. + * @param {*} media - An HTML or p5.js image, video, or canvas element to run the prediction on. * @returns {Promise} an array of poses. */ - async detect(...inputs) { - //Parse out the input parameters - const argumentObject = handleArguments(...inputs); - argumentObject.require( - "image", - "An html or p5.js image, video, or canvas element argument is required for detect()." - ); - const { image, callback } = argumentObject; - - await mediaReady(image, false); + async detect(media) { + await mediaReady(media, false); const predictions = await this.model.estimatePoses( - image, + media, this.runtimeConfig ); - let result = predictions; - result = this.addKeypoints(result); - if (typeof callback === "function") callback(result); - return result; - } - - /** - * Repeatedly outputs pose predictions through a callback function. - * Calls the internal detectLoop() function. - * @param {*} media - An HMTL or p5.js image, video, or canvas element to run the prediction on. - * @param {gotPoses} callback - A callback function to handle the predictions. - * @returns {Promise} an array of predictions. - */ - detectStart(...inputs) { - // Parse out the input parameters - const argumentObject = handleArguments(...inputs); - argumentObject.require( - "image", - "An html or p5.js image, video, or canvas element argument is required for detectStart()." - ); - argumentObject.require( - "callback", - "A callback function argument is required for detectStart()." - ); - this.detectMedia = argumentObject.image; - this.detectCallback = argumentObject.callback; - - this.signalStop = false; - if (!this.detecting) { - this.detecting = true; - this.detectLoop(); - } - if (this.prevCall === "start") { - console.warn( - "detectStart() was called more than once without calling detectStop(). The lastest detectStart() call will be used and the previous calls will be ignored." - ); - } - this.prevCall = "start"; - } - - /** - * Internal function that calls estimatePoses in a loop - * Can be started by detectStart() and terminated by detectStop() - * @private - */ - async detectLoop() { - await mediaReady(this.detectMedia, false); - while (!this.signalStop) { - const predictions = await this.model.estimatePoses( - this.detectMedia, - this.runtimeConfig - ); - let result = predictions; - result = this.addKeypoints(result); - this.detectCallback(result); - // wait for the frame to update - await tf.nextFrame(); - } - this.detecting = false; - this.signalStop = false; - } - - /** - * Stops the detection loop before next detection loop runs. - */ - detectStop() { - if (this.detecting) this.signalStop = true; - this.prevCall = "stop"; + return this.addKeypoints(predictions); } /** @@ -268,12 +180,12 @@ class BodyPose { /** * Factory function that returns a BodyPose instance. - * @returns {BodyPose} A BodyPose instance. + * @returns {ImageDetector} A BodyPose instance. */ const bodyPose = (...inputs) => { const { string, options = {}, callback } = handleArguments(...inputs); const instance = new BodyPose(string, options, callback); - return instance; + return new ImageDetector(instance); }; export default bodyPose; diff --git a/src/BodySegmentation/index.js b/src/BodySegmentation/index.js index 43754850..9257010b 100644 --- a/src/BodySegmentation/index.js +++ b/src/BodySegmentation/index.js @@ -10,6 +10,7 @@ import * as tf from "@tensorflow/tfjs"; import * as tfBodySegmentation from "@tensorflow-models/body-segmentation"; +import ImageDetector from "../ImageDetector"; import callCallback from "../utils/callcallback"; import handleArguments from "../utils/handleArguments"; import BODYPIX_PALETTE from "./BODYPIX_PALETTE"; @@ -18,7 +19,7 @@ import { mediaReady } from "../utils/imageUtilities"; class BodySegmentation { /** * Create BodyPix. - * @param {HTMLVideoElement} [video] - An HTMLVideoElement. + * @param {string} modelName * @param {object} [options] - An object with options. * @param {function} [callback] - A callback to be called when the model is ready. */ @@ -27,12 +28,10 @@ class BodySegmentation { if (this.p5PreLoadExists()) window._incrementPreload(); this.modelName = modelName; - this.video = video; this.model = null; this.config = options; this.runtimeConfig = {}; - this.detectMedia = null; - this.detectCallback = null; + this.ready = callCallback(this.loadModel(), callback); } @@ -117,22 +116,17 @@ class BodySegmentation { ); // for compatibility with p5's preload() - if (this.p5PreLoadExists) window._decrementPreload(); + if (this.p5PreLoadExists()) window._decrementPreload(); return this; } + /** * Calls segmentPeople in a loop. * Can be started by detectStart() and terminated by detectStop(). * @private */ - async detect(...inputs) { - const argumentObject = handleArguments(...inputs); - argumentObject.require( - "image", - "An html or p5.js image, video, or canvas element argument is required for detectStart()." - ); - const { image, callback } = argumentObject; + async detect(image) { await mediaReady(image, false); @@ -165,99 +159,12 @@ class BodySegmentation { } result.mask = this.generateP5Image(result.maskImageData); - if (callback) callback(result); return result; } - /** - * Repeatedly outputs hand predictions through a callback function. - * @param {*} [media] - An HMTL or p5.js image, video, or canvas element to run the prediction on. - * @param {gotHands} [callback] - A callback to handle the hand detection results. - */ - detectStart(...inputs) { - // Parse out the input parameters - const argumentObject = handleArguments(...inputs); - argumentObject.require( - "image", - "An html or p5.js image, video, or canvas element argument is required for detectStart()." - ); - argumentObject.require( - "callback", - "A callback function argument is required for detectStart()." - ); - this.detectMedia = argumentObject.image; - this.detectCallback = argumentObject.callback; - - this.signalStop = false; - if (!this.detecting) { - this.detecting = true; - this.detectLoop(); - } - if (this.prevCall === "start") { - console.warn( - "detectStart() was called more than once without calling detectStop(). Only the latest detectStart() call will take effect." - ); - } - this.prevCall = "start"; - } - - /** - * Stops the detection loop before next detection loop runs. - */ - detectStop() { - if (this.detecting) this.signalStop = true; - this.prevCall = "stop"; - } - - /** - * Calls segmentPeople in a loop. - * Can be started by detectStart() and terminated by detectStop(). - * @private - */ - async detectLoop() { - await mediaReady(this.detectMedia, false); - while (!this.signalStop) { - const segmentation = await this.model.segmentPeople( - this.detectMedia, - this.runtimeConfig - ); - - const result = {}; - switch (this.runtimeConfig.maskType) { - case "background": - result.maskImageData = await tfBodySegmentation.toBinaryMask( - segmentation, - { r: 0, g: 0, b: 0, a: 255 }, - { r: 0, g: 0, b: 0, a: 0 } - ); - break; - case "person": - result.maskImageData = await tfBodySegmentation.toBinaryMask( - segmentation - ); - break; - case "parts": - result.maskImageData = await tfBodySegmentation.toColoredMask( - segmentation, - tfBodySegmentation.bodyPixMaskValueToRainbowColor, - { r: 255, g: 255, b: 255, a: 255 } - ); - result.bodyParts = BODYPIX_PALETTE; - } - result.mask = this.generateP5Image(result.maskImageData); - - this.detectCallback(result); - await tf.nextFrame(); - } - - this.detecting = false; - this.signalStop = false; - } /** * Generate a p5 image from the image data - * @param imageData - a ImageData object - * @param width - the width of the p5 image - * @param height - the height of the p5 image + * @param {ImageData} imageData - a ImageData object * @return a p5.Image object */ generateP5Image(imageData) { @@ -288,12 +195,12 @@ class BodySegmentation { /** * Factory function that returns a Facemesh instance - * @returns {Object} A new bodySegmentation instance + * @returns {ImageDetector} A new bodySegmentation instance */ const bodySegmentation = (...inputs) => { const { string, options = {}, callback } = handleArguments(...inputs); const instance = new BodySegmentation(string, options, callback); - return instance; + return new ImageDetector(instance); }; export default bodySegmentation; diff --git a/src/FaceMesh/index.js b/src/FaceMesh/index.js index 467fabcf..3a3e1d15 100644 --- a/src/FaceMesh/index.js +++ b/src/FaceMesh/index.js @@ -10,16 +10,19 @@ import * as tf from "@tensorflow/tfjs"; import * as faceLandmarksDetection from "@tensorflow-models/face-landmarks-detection"; +import ImageDetector from "../ImageDetector"; import callCallback from "../utils/callcallback"; import handleArguments from "../utils/handleArguments"; -import { mediaReady } from "../utils/imageUtilities"; +/** + * @implements {SpecificDetectorImplementation} + */ class FaceMesh { /** * An options object to configure FaceMesh settings * @typedef {Object} configOptions - * @property {number} maxFacess - The maximum number of faces to detect. Defaults to 2. - * @property {boolean} refineLandmarks - Refine the ladmarks. Defaults to false. + * @property {number} maxFaces - The maximum number of faces to detect. Defaults to 2. + * @property {boolean} refineLandmarks - Refine the landmarks. Defaults to false. * @property {boolean} flipHorizontal - Flip the result horizontally. Defaults to false. * @property {string} runtime - The runtime to use. "mediapipe"(default) or "tfjs". * @@ -41,13 +44,6 @@ class FaceMesh { this.model = null; this.config = options; this.runtimeConfig = {}; - this.detectMedia = null; - this.detectCallback = null; - - // flags for detectStart() and detectStop() - this.detecting = false; // true when detection loop is running - this.signalStop = false; // true when detectStop() is called and detecting is true - this.prevCall = ""; // "start" or "stop", used for giving warning messages with detectStart() is called twice in a row this.ready = callCallback(this.loadModel(), callback); } @@ -78,99 +74,22 @@ class FaceMesh { ); // for compatibility with p5's preload() - if (this.p5PreLoadExists) window._decrementPreload(); + if (this.p5PreLoadExists()) window._decrementPreload(); return this; } /** * Asynchronously output a single face prediction result when called - * @param {*} [media] - An HMTL or p5.js image, video, or canvas element to run the prediction on. - * @param {function} [callback] - A callback function to handle the predictions. + * @param {*} [media] - An HTML or p5.js image, video, or canvas element to run the prediction on. * @returns {Promise} an array of predictions. */ - async detect(...inputs) { - // Parse out the input parameters - const argumentObject = handleArguments(...inputs); - argumentObject.require( - "image", - "An html or p5.js image, video, or canvas element argument is required for detect()." - ); - const { image, callback } = argumentObject; - - await mediaReady(image, false); + async detect(media) { const predictions = await this.model.estimateFaces( - image, + media, this.runtimeConfig ); - let result = predictions; - result = this.addKeypoints(result); - if (typeof callback === "function") callback(result); - return result; - } - - /** - * Repeatedly output face predictions through a callback function - * @param {*} [media] - An HMTL or p5.js image, video, or canvas element to run the prediction on. - * @param {function} [callback] - A callback function to handle the predictions. - * @returns {Promise} an array of predictions. - */ - detectStart(...inputs) { - // Parse out the input parameters - const argumentObject = handleArguments(...inputs); - argumentObject.require( - "image", - "An html or p5.js image, video, or canvas element argument is required for detectStart()." - ); - argumentObject.require( - "callback", - "A callback function argument is required for detectStart()." - ); - this.detectMedia = argumentObject.image; - this.detectCallback = argumentObject.callback; - - this.signalStop = false; - if (!this.detecting) { - this.detecting = true; - this.detectLoop(); - } - if (this.prevCall === "start") { - console.warn( - "detectStart() was called more than once without calling detectStop(). The lastest detectStart() call will be used and the previous calls will be ignored." - ); - } - this.prevCall = "start"; - } - - /** - * Stop the detection loop before next detection loop runs. - */ - detectStop() { - if (this.detecting) this.signalStop = true; - this.prevCall = "stop"; - } - - /** - * Internal function to call estimateFaces in a loop - * Can be started by detectStart() and terminated by detectStop() - * - * @private - */ - async detectLoop() { - await mediaReady(this.detectMedia, false); - while (!this.signalStop) { - const predictions = await this.model.estimateFaces( - this.detectMedia, - this.runtimeConfig - ); - let result = predictions; - result = this.addKeypoints(result); - this.detectCallback(result); - // wait for the frame to update - await tf.nextFrame(); - } - this.detecting = false; - this.signalStop = false; + return this.addKeypoints(predictions); } /** @@ -265,12 +184,12 @@ class FaceMesh { /** * Factory function that returns a FaceMesh instance - * @returns {Object} A new faceMesh instance + * @returns {ImageDetector} A new faceMesh instance */ const faceMesh = (...inputs) => { const { options = {}, callback } = handleArguments(...inputs); const instance = new FaceMesh(options, callback); - return instance; + return new ImageDetector(instance); }; export default faceMesh; diff --git a/src/HandPose/index.js b/src/HandPose/index.js index 388d98bf..9a4badb0 100644 --- a/src/HandPose/index.js +++ b/src/HandPose/index.js @@ -10,6 +10,7 @@ import * as tf from "@tensorflow/tfjs"; import * as handPoseDetection from "@tensorflow-models/hand-pose-detection"; +import ImageDetector from "../ImageDetector"; import callCallback from "../utils/callcallback"; import handleArguments from "../utils/handleArguments"; import { mediaReady } from "../utils/imageUtilities"; @@ -42,13 +43,6 @@ class HandPose { this.model = null; this.config = options; this.runtimeConfig = {}; - this.detectMedia = null; - this.detectCallback = null; - - // flags for detectStart() and detectStop() - this.detecting = false; // True when detection loop is running - this.signalStop = false; // Signal to stop the loop - this.prevCall = ""; // Track previous call to detectStart() or detectStop() this.ready = callCallback(this.loadModel(), callback); } @@ -79,104 +73,24 @@ class HandPose { this.model = await handPoseDetection.createDetector(pipeline, modelConfig); // for compatibility with p5's preload() - if (this.p5PreLoadExists) window._decrementPreload(); + if (this.p5PreLoadExists()) window._decrementPreload(); return this; } - /** - * A callback function that handles the handPose detection results. - * @callback gotHands - * @param {Array} results - The detection output. - */ - /** * Asynchronously outputs a single hand landmark detection result when called. * Supports both callback and promise. - * @param {*} [media] - An HMTL or p5.js image, video, or canvas element to run the detection on. - * @param {gotHands} [callback] - Optional. A callback to handle the hand detection result. + * @param {*} [media] - An HTML or p5.js image, video, or canvas element to run the detection on. * @returns {Promise} The detection result. */ - async detect(...inputs) { - //Parse out the input parameters - const argumentObject = handleArguments(...inputs); - argumentObject.require( - "image", - "An html or p5.js image, video, or canvas element argument is required for detect()." - ); - const { image, callback } = argumentObject; - - await mediaReady(image, false); + async detect(media) { + await mediaReady(media, false); const predictions = await this.model.estimateHands( - image, + media, this.runtimeConfig ); - let result = predictions; - result = this.addKeypoints(result); - if (typeof callback === "function") callback(result); - return result; - } - - /** - * Repeatedly outputs hand predictions through a callback function. - * @param {*} [media] - An HMTL or p5.js image, video, or canvas element to run the prediction on. - * @param {gotHands} [callback] - A callback to handle the hand detection results. - */ - detectStart(...inputs) { - // Parse out the input parameters - const argumentObject = handleArguments(...inputs); - argumentObject.require( - "image", - "An html or p5.js image, video, or canvas element argument is required for detectStart()." - ); - argumentObject.require( - "callback", - "A callback function argument is required for detectStart()." - ); - this.detectMedia = argumentObject.image; - this.detectCallback = argumentObject.callback; - - this.signalStop = false; - if (!this.detecting) { - this.detecting = true; - this.detectLoop(); - } - if (this.prevCall === "start") { - console.warn( - "detectStart() was called more than once without calling detectStop(). Only the latest detectStart() call will take effect." - ); - } - this.prevCall = "start"; - } - - /** - * Stops the detection loop before next detection loop runs. - */ - detectStop() { - if (this.detecting) this.signalStop = true; - this.prevCall = "stop"; - } - - /** - * Calls estimateHands in a loop. - * Can be started by detectStart() and terminated by detectStop(). - * @private - */ - async detectLoop() { - await mediaReady(this.detectMedia, false); - while (!this.signalStop) { - const predictions = await this.model.estimateHands( - this.detectMedia, - this.runtimeConfig - ); - let result = predictions; - result = this.addKeypoints(result); - this.detectCallback(result); - // wait for the frame to update - await tf.nextFrame(); - } - this.detecting = false; - this.signalStop = false; + return this.addKeypoints(predictions); } /** @@ -221,12 +135,12 @@ class HandPose { /** * Factory function that returns a new HandPose instance. - * @returns {HandPose} A new handPose instance. + * @returns {ImageDetector} A new handPose instance. */ const handPose = (...inputs) => { const { options = {}, callback } = handleArguments(...inputs); const instance = new HandPose(options, callback); - return instance; + return new ImageDetector(instance); }; export default handPose; diff --git a/src/ImageDetector/index.js b/src/ImageDetector/index.js new file mode 100644 index 00000000..4af37d76 --- /dev/null +++ b/src/ImageDetector/index.js @@ -0,0 +1,162 @@ +import * as tf from '@tensorflow/tfjs'; +import handleArguments from '../utils/handleArguments'; +import { mediaReady } from '../utils/imageUtilities'; + +/** + * @typedef {Object} SpecificDetectorImplementation + * + * @property {Promise} ready - lets the parent detector know that the + * specific implementation is ready. + * + * @property {(input: tf.Tensor3D) => Promise} detect - core detection method. + * the ImageDetector will call the `detect` method of the specific implementation. + * It should accept a TensorFlow tensor?? or an image?? + * And return an array of detections?? + */ + +/** + * Helper class for handling the public API of detector models (facemesh, etc.) + * Exposes the public methods for single and continuous detection. + * Executes the detection using whatever model is passed to the constructor. + */ +export default class ImageDetector { + + /** + * @param {SpecificDetectorImplementation} specificImplementation + */ + constructor(specificImplementation) { + /** + * @type {SpecificDetectorImplementation} + */ + this.implementation = specificImplementation; + + /** + * @type {Promise} + * TODO: do we need to handle onReady callbacks? + */ + this.ready = this.implementation.ready; + + /** + * @type {InputImage | null} + * The video or image used for continuous detection. + */ + this.detectMedia = null; + /** + * @type {function | null} + * Function to call with the results of each detection. + */ + this.detectCallback = null; + + // flags for detectStart() and detectStop() + /** + * @type {boolean} + * true when detection loop is running + */ + this.detecting = false; + /** + * @type {boolean} + * Signal to stop the loop + */ + this.signalStop = false; + /** + * @type {"start" | "stop" | ""} + * Track previous call to detectStart() or detectStop(), + * used for giving warning messages when detectStart() is called twice in a row. + */ + this.prevCall = ""; + } + + /** + * Repeatedly output detections through a callback function + * @param {*} media - An HTML or p5.js image, video, or canvas element to run the detection on. + * @param {function} callback - A callback function to handle each detection. + * @void + */ + detectStart(...inputs) { + // Parse out the input parameters + const argumentObject = handleArguments(...inputs); + argumentObject.require( + "image", + "An html or p5.js image, video, or canvas element argument is required for detectStart()." + ); + argumentObject.require( + "callback", + "A callback function argument is required for detectStart()." + ); + this.detectMedia = argumentObject.image; + this.detectCallback = argumentObject.callback; + + this.signalStop = false; + if (!this.detecting) { + this.detecting = true; + this.detectLoop(); + } + if (this.prevCall === "start") { + console.warn( + "detectStart() was called more than once without calling detectStop(). The latest detectStart() call will be used and the previous calls will be ignored." + ); + } + this.prevCall = "start"; + } + + /** + * Stop the detection loop before next detection runs. + */ + detectStop() { + if (this.detecting) this.signalStop = true; + this.prevCall = "stop"; + } + + /** + * Internal function to call the detect method repeatedly in a loop + * Can be started by detectStart() and terminated by detectStop() + * + * @private + */ + async detectLoop() { + // Make sure that both the model and the media are loaded before beginning. + await Promise.all([ + this.ready, + mediaReady(this.detectMedia, false) + ]); + + // Continuous detection loop. + while (!this.signalStop) { + const result = await this.implementation.detect(this.detectMedia); + this.detectCallback(result); + // wait for the frame to update + await tf.nextFrame(); + } + + // Update flags when done. + this.detecting = false; + this.signalStop = false; + } + + /** + * Asynchronously output a single detection result when called + * @param {*} media - An HTML or p5.js image, video, or canvas element to run the detection on. + * @param {function} [callback] - A callback function to handle the detection results. + * @returns {Promise} an array of detections. + */ + async detect(...inputs) { + // Parse out the input parameters + const argumentObject = handleArguments(...inputs); + argumentObject.require( + "image", + "An html or p5.js image, video, or canvas element argument is required for detect()." + ); + const { image, callback } = argumentObject; + + // Make sure that both the model and the media are loaded before beginning. + await Promise.all([ + this.ready, + mediaReady(this.detectMedia, false) + ]); + + const result = await this.implementation.detect(image); + if (typeof callback === "function") callback(result); + return result; + } + +}