diff --git a/src/Handpose/index.js b/src/Handpose/index.js index 7059e757..cc559b14 100644 --- a/src/Handpose/index.js +++ b/src/Handpose/index.js @@ -44,7 +44,7 @@ class Handpose extends EventEmitter { modelType: this.config?.modelType ?? "full", // use full version of the model by default solutionPath: "https://cdn.jsdelivr.net/npm/@mediapipe/hands", // fetch model from mediapipe server }; - + await tf.ready(); this.model = await handPoseDetection.createDetector(pipeline, modelConfig); this.modelReady = true; diff --git a/src/NeuralNetwork/index.js b/src/NeuralNetwork/index.js index 794bf99c..c5556edf 100644 --- a/src/NeuralNetwork/index.js +++ b/src/NeuralNetwork/index.js @@ -113,6 +113,7 @@ class DiyNeuralNetwork { * @param {*} callback */ init(callback) { + tf.setBackend("webgl"); // check if the a static model should be built based on the inputs and output properties if (this.options.noTraining === true) { this.createLayersNoTraining(); diff --git a/src/PoseDetection/index.js b/src/PoseDetection/index.js index 1e486597..c78bed62 100644 --- a/src/PoseDetection/index.js +++ b/src/PoseDetection/index.js @@ -78,10 +78,9 @@ class PoseDetection extends EventEmitter { bodyPoseDetection.movenet.modelType.MULTIPOSE_LIGHTNING; } // Load the detector model - await tf.setBackend("webgl"); + await tf.ready(); this.model = await bodyPoseDetection.createDetector(pipeline, modelConfig); this.modelReady = true; - if (this.video) { this.predict(); }