33#include < cctype>
44#include < filesystem>
55#include < iostream>
6+ #include < memory>
7+ #include < thread>
8+ #include < atomic>
9+ #include < mutex>
610#include < nlohmann/json.hpp>
11+ #include < nlohmann/json_fwd.hpp>
712
813#include " utility/Environment.hpp"
914#include " utility/Logging.hpp"
1217
1318#ifdef DEPTHAI_ENABLE_CURL
1419 #include < cpr/api.h>
20+ #include < cpr/cprtypes.h>
1521 #include < cpr/parameters.h>
1622 #include < cpr/status_codes.h>
1723#endif
@@ -96,8 +102,11 @@ class ZooManager {
96102
97103 /* *
98104 * @brief Download model from model zoo
105+ *
106+ * @param responseJson: JSON with download links
107+ * @param cprCallback: Progress callback
99108 */
100- void downloadModel (const nlohmann::json& responseJson);
109+ void downloadModel (const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback );
101110
102111 /* *
103112 * @brief Return path to model in cache
@@ -328,7 +337,7 @@ nlohmann::json ZooManager::fetchModelDownloadLinks() {
328337 return responseJson;
329338}
330339
331- void ZooManager::downloadModel (const nlohmann::json& responseJson) {
340+ void ZooManager::downloadModel (const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback ) {
332341 // Extract download links from response
333342 auto downloadLinks = responseJson[" download_links" ].get <std::vector<std::string>>();
334343 auto downloadHash = responseJson[" hash" ].get <std::string>();
@@ -346,7 +355,7 @@ void ZooManager::downloadModel(const nlohmann::json& responseJson) {
346355
347356 // Download all files and store them in cache folder
348357 for (const auto & downloadLink : downloadLinks) {
349- cpr::Response downloadResponse = cpr::Get (cpr::Url (downloadLink));
358+ cpr::Response downloadResponse = cpr::Get (cpr::Url (downloadLink), *cprCallback );
350359 if (checkIsErrorModelDownload (downloadResponse)) {
351360 removeModelCacheFolder ();
352361 throw std::runtime_error (generateErrorMessageModelDownload (downloadResponse));
@@ -395,7 +404,11 @@ std::string ZooManager::loadModelFromCache() const {
395404 return std::filesystem::absolute (folderFiles[0 ]).string ();
396405}
397406
398- std::string getModelFromZoo (const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
407+ std::string getModelFromZoo (const NNModelDescription& modelDescription,
408+ bool useCached,
409+ const std::string& cacheDirectory,
410+ const std::string& apiKey,
411+ const std::function<void (cpr::cpr_off_t , cpr::cpr_off_t , cpr::cpr_off_t , cpr::cpr_off_t , intptr_t )>& progressCallback) {
399412 // Check if model description is valid
400413 if (!modelDescription.check ()) throw std::runtime_error (" Invalid model description:\n " + modelDescription.toString ());
401414
@@ -466,9 +479,21 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
466479 // Create cache folder
467480 zooManager.createCacheFolder ();
468481
482+ // Create download progress callback
483+ std::unique_ptr<cpr::ProgressCallback> cprCallback;
484+ if (progressCallback) {
485+ cprCallback = std::make_unique<cpr::ProgressCallback>(
486+ [&](cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
487+ progressCallback (downloadTotal, downloadNow, uploadTotal, uploadNow, userdata);
488+ return true ;
489+ });
490+ } else {
491+ cprCallback = std::make_unique<cpr::ProgressCallback>([](cpr::cpr_off_t , cpr::cpr_off_t , cpr::cpr_off_t , cpr::cpr_off_t , intptr_t ) { return true ; });
492+ }
493+
469494 // Download model
470495 logger::info (" Downloading model from model zoo" );
471- zooManager.downloadModel (responseJson);
496+ zooManager.downloadModel (responseJson, std::move (cprCallback) );
472497
473498 // Store model as yaml in the cache folder
474499 std::string yamlPath = combinePaths (zooManager.getModelCacheFolderPath (cacheDirectory), " model.yaml" );
@@ -479,6 +504,84 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
479504 return modelPath;
480505}
481506
507+ std::string getModelFromZoo (const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
508+ return getModelFromZoo (modelDescription, useCached, cacheDirectory, apiKey, nullptr );
509+ }
510+
511+ struct JsonDownloadProgressManager {
512+ JsonDownloadProgressManager (size_t updateIntervalMs = 1000 ) : updateIntervalMs(updateIntervalMs) {
513+ pause ();
514+ }
515+
516+ size_t updateIntervalMs;
517+
518+ std::string model;
519+ size_t bytesDownloaded = 0 ;
520+ size_t bytesTotal = 0 ;
521+ std::mutex mutex;
522+
523+ std::thread thread;
524+ std::atomic<bool > running;
525+ std::atomic<bool > started;
526+ std::atomic<bool > firstUpdate;
527+
528+ bool pause () {
529+ std::lock_guard<std::mutex> lock (mutex);
530+ started = false ;
531+ firstUpdate = true ;
532+ return true ;
533+ }
534+
535+ bool resume () {
536+ std::lock_guard<std::mutex> lock (mutex);
537+ started = true ;
538+ firstUpdate = true ;
539+ return true ;
540+ }
541+
542+ void update (cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
543+ (void )uploadTotal;
544+ (void )uploadNow;
545+ (void )userdata;
546+ std::lock_guard<std::mutex> lock (mutex);
547+ this ->bytesDownloaded = downloadNow;
548+ this ->bytesTotal = downloadTotal;
549+ firstUpdate = false ;
550+ }
551+
552+ nlohmann::json getJson () {
553+ std::lock_guard<std::mutex> lock (mutex);
554+ nlohmann::json json;
555+ json[" model" ] = model;
556+ json[" bytes_downloaded" ] = bytesDownloaded;
557+ json[" bytes_total" ] = bytesTotal;
558+ return json;
559+ }
560+
561+ void setModel (const std::string& model) {
562+ std::lock_guard<std::mutex> lock (mutex);
563+ this ->model = model;
564+ }
565+
566+ void startThread () {
567+ running = true ;
568+ thread = std::thread ([this ]() {
569+ while (running) {
570+ std::this_thread::sleep_for (std::chrono::milliseconds (updateIntervalMs));
571+ if (!started || firstUpdate) continue ;
572+ auto json = getJson ();
573+ std::string jsonStr = json.dump ();
574+ std::cout << jsonStr << std::endl;
575+ }
576+ });
577+ }
578+
579+ void stopThread () {
580+ running = false ;
581+ thread.join ();
582+ }
583+ };
584+
482585bool downloadModelsFromZoo (const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
483586 logger::info (" Downloading models from zoo" );
484587 // Make sure 'path' exists
@@ -497,6 +600,9 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
497600 }
498601 }
499602
603+ auto jsonDownloadProgressManager = std::make_unique<JsonDownloadProgressManager>(100 );
604+ jsonDownloadProgressManager->startThread ();
605+ auto callback = std::bind (&JsonDownloadProgressManager::update, jsonDownloadProgressManager.get (), std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5);
500606 // Download models from yaml files
501607 int numSuccess = 0 , numFail = 0 ;
502608 for (size_t i = 0 ; i < models.size (); ++i) {
@@ -505,17 +611,21 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
505611
506612 // Download model - ignore the returned model path here == we are only interested in downloading the model
507613 try {
614+ jsonDownloadProgressManager->setModel (modelName);
615+ jsonDownloadProgressManager->resume ();
508616 logger::info (" Downloading model [{} / {}]: {}" , i + 1 , models.size (), modelName);
509617 auto modelDescription = NNModelDescription::fromYamlFile (modelName, path);
510- getModelFromZoo (modelDescription, true , cacheDirectory, apiKey);
618+ getModelFromZoo (modelDescription, true , cacheDirectory, apiKey, callback );
511619 logger::info (" Downloaded model [{} / {}]: {}" , i + 1 , models.size (), modelName);
512620 numSuccess++;
513621 } catch (const std::exception& e) {
514622 logger::error (" Failed to download model [{} / {}]: {} in folder {}\n {}" , i + 1 , models.size (), modelName, path, e.what ());
515623 numFail++;
516624 }
625+ jsonDownloadProgressManager->pause ();
517626 }
518627
628+ jsonDownloadProgressManager->stopThread ();
519629 logger::info (" Downloaded {} models from folder {} | {} failed." , numSuccess, path, numFail);
520630 return numFail == 0 ;
521631}
0 commit comments