44#include < filesystem>
55#include < iostream>
66#include < memory>
7- #include < thread>
8- #include < atomic>
97#include < mutex>
108#include < nlohmann/json.hpp>
119#include < nlohmann/json_fwd.hpp>
@@ -332,8 +330,6 @@ nlohmann::json ZooManager::fetchModelDownloadLinks() {
332330 throw std::runtime_error (generateErrorMessageHub (response));
333331 }
334332
335- std::cout << " Response: " << response.text << std::endl;
336-
337333 // Extract download links from response
338334 nlohmann::json responseJson = nlohmann::json::parse (response.text );
339335 return responseJson;
@@ -406,11 +402,145 @@ std::string ZooManager::loadModelFromCache() const {
406402 return std::filesystem::absolute (folderFiles[0 ]).string ();
407403}
408404
405+ class CprCallback {
406+ public:
407+ virtual ~CprCallback () = default ;
408+ CprCallback (const std::string& modelName) : modelName(modelName) {}
409+
410+ virtual void cprCallback (
411+ cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) = 0;
412+
413+ virtual std::unique_ptr<cpr::ProgressCallback> getCprProgressCallback () {
414+ return std::make_unique<cpr::ProgressCallback>(
415+ [this ](cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
416+ this ->cprCallback (downloadTotal, downloadNow, uploadTotal, uploadNow, userdata);
417+ return true ;
418+ });
419+ }
420+
421+ protected:
422+ std::string modelName;
423+ };
424+
425+ class JsonCprCallback : public CprCallback {
426+ constexpr static long long PRINT_INTERVAL_MS = 100 ;
427+
428+ public:
429+ JsonCprCallback (const std::string& modelName) : CprCallback(modelName) {
430+ startTime = std::chrono::steady_clock::time_point::min ();
431+ }
432+
433+ void print (long downloadTotal, long downloadNow, const std::string& modelName) {
434+ nlohmann::json json = {
435+ {" download_total" , downloadTotal},
436+ {" download_now" , downloadNow},
437+ {" model_name" , modelName},
438+ };
439+ std::cout << json.dump () << std::endl;
440+ }
441+
442+ void cprCallback (
443+ cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
444+ (void )uploadTotal;
445+ (void )uploadNow;
446+ (void )userdata;
447+
448+ bool firstCall = startTime == std::chrono::steady_clock::time_point::min ();
449+ if (firstCall || downloadTotal == 0 ) {
450+ startTime = std::chrono::steady_clock::now ();
451+ }
452+
453+ bool shouldPrint = std::chrono::steady_clock::now () - startTime > std::chrono::milliseconds (PRINT_INTERVAL_MS) || this ->downloadTotal != downloadTotal;
454+
455+ if (shouldPrint) {
456+ print (downloadTotal, downloadNow, modelName);
457+ startTime = std::chrono::steady_clock::now ();
458+ }
459+
460+ this ->downloadTotal = downloadTotal;
461+ this ->downloadNow = downloadNow;
462+ }
463+
464+ ~JsonCprCallback () override {
465+ if (downloadTotal != 0 ) {
466+ print (downloadTotal, downloadNow, modelName);
467+ }
468+ }
469+
470+ private:
471+ long downloadTotal = 0 ;
472+ long downloadNow = 0 ;
473+ std::chrono::steady_clock::time_point startTime;
474+ };
475+
476+ class PrettyCprCallback : public CprCallback {
477+ public:
478+ PrettyCprCallback (const std::string& modelName) : CprCallback(modelName), finalProgressPrinted(false ) {}
479+
480+ void cprCallback (
481+ cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
482+ (void )uploadTotal;
483+ (void )uploadNow;
484+ (void )userdata;
485+
486+ if (finalProgressPrinted) return ;
487+
488+ if (downloadTotal > 0 ) {
489+ float progress = static_cast <float >(downloadNow) / downloadTotal;
490+ int barWidth = 50 ;
491+ int pos = static_cast <int >(barWidth * progress);
492+
493+ std::cout << " \r Downloading " << modelName << " [" ;
494+ for (int i = 0 ; i < barWidth; ++i) {
495+ if (i < pos)
496+ std::cout << " =" ;
497+ else if (i == pos)
498+ std::cout << " >" ;
499+ else
500+ std::cout << " " ;
501+ }
502+ std::cout << " ] " << std::fixed << std::setprecision (3 ) << progress * 100 .0f << " % " << downloadNow / 1024 .0f / 1024 .0f << " /"
503+ << downloadTotal / 1024 .0f / 1024 .0f << " MB" ;
504+
505+ if (downloadNow == downloadTotal) {
506+ std::cout << std::endl;
507+ finalProgressPrinted = true ;
508+ } else {
509+ std::cout << " \r " ;
510+ std::cout.flush ();
511+ }
512+ }
513+ }
514+
515+ private:
516+ bool finalProgressPrinted;
517+ };
518+
519+ class NoneCprCallback : public CprCallback {
520+ public:
521+ NoneCprCallback (const std::string& modelName) : CprCallback(modelName) {}
522+
523+ void cprCallback (cpr::cpr_off_t , cpr::cpr_off_t , cpr::cpr_off_t , cpr::cpr_off_t , intptr_t ) override {
524+ // Do nothing
525+ }
526+ };
527+
528+ std::unique_ptr<CprCallback> getCprCallback (const std::string& format, const std::string& name) {
529+ if (format == " json" ) {
530+ return std::make_unique<JsonCprCallback>(name);
531+ } else if (format == " pretty" ) {
532+ return std::make_unique<PrettyCprCallback>(name);
533+ } else if (format == " none" ) {
534+ return std::make_unique<NoneCprCallback>(name);
535+ }
536+ throw std::runtime_error (" Invalid format: " + format);
537+ }
538+
409539std::string getModelFromZoo (const NNModelDescription& modelDescription,
410540 bool useCached,
411541 const std::string& cacheDirectory,
412542 const std::string& apiKey,
413- const std::function< void (cpr:: cpr_off_t , cpr:: cpr_off_t , cpr:: cpr_off_t , cpr:: cpr_off_t , intptr_t )>& progressCallback ) {
543+ const std::string& progressFormat ) {
414544 // Check if model description is valid
415545 if (!modelDescription.check ()) throw std::runtime_error (" Invalid model description:\n " + modelDescription.toString ());
416546
@@ -482,20 +612,11 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription,
482612 zooManager.createCacheFolder ();
483613
484614 // Create download progress callback
485- std::unique_ptr<cpr::ProgressCallback> cprCallback;
486- if (progressCallback) {
487- cprCallback = std::make_unique<cpr::ProgressCallback>(
488- [&](cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
489- progressCallback (downloadTotal, downloadNow, uploadTotal, uploadNow, userdata);
490- return true ;
491- });
492- } else {
493- cprCallback = std::make_unique<cpr::ProgressCallback>([](cpr::cpr_off_t , cpr::cpr_off_t , cpr::cpr_off_t , cpr::cpr_off_t , intptr_t ) { return true ; });
494- }
615+ std::unique_ptr<CprCallback> cprCallback = getCprCallback (progressFormat, modelDescription.globalMetadataEntryName .size () > 0 ? modelDescription.globalMetadataEntryName : modelDescription.model );
495616
496617 // Download model
497618 logger::info (" Downloading model from model zoo" );
498- zooManager.downloadModel (responseJson, std::move ( cprCallback));
619+ zooManager.downloadModel (responseJson, cprCallback-> getCprProgressCallback ( ));
499620
500621 // Store model as yaml in the cache folder
501622 std::string yamlPath = combinePaths (zooManager.getModelCacheFolderPath (cacheDirectory), " model.yaml" );
@@ -506,85 +627,7 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription,
506627 return modelPath;
507628}
508629
509- std::string getModelFromZoo (const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
510- return getModelFromZoo (modelDescription, useCached, cacheDirectory, apiKey, nullptr );
511- }
512-
513- struct JsonDownloadProgressManager {
514- JsonDownloadProgressManager (size_t updateIntervalMs = 1000 ) : updateIntervalMs(updateIntervalMs) {
515- pause ();
516- }
517-
518- size_t updateIntervalMs;
519-
520- std::string model;
521- size_t bytesDownloaded = 0 ;
522- size_t bytesTotal = 0 ;
523- std::mutex mutex;
524-
525- std::thread thread;
526- std::atomic<bool > running;
527- std::atomic<bool > started;
528- std::atomic<bool > firstUpdate;
529-
530- bool pause () {
531- std::lock_guard<std::mutex> lock (mutex);
532- started = false ;
533- firstUpdate = true ;
534- return true ;
535- }
536-
537- bool resume () {
538- std::lock_guard<std::mutex> lock (mutex);
539- started = true ;
540- firstUpdate = true ;
541- return true ;
542- }
543-
544- void update (cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
545- (void )uploadTotal;
546- (void )uploadNow;
547- (void )userdata;
548- std::lock_guard<std::mutex> lock (mutex);
549- this ->bytesDownloaded = downloadNow;
550- this ->bytesTotal = downloadTotal;
551- firstUpdate = false ;
552- }
553-
554- nlohmann::json getJson () {
555- std::lock_guard<std::mutex> lock (mutex);
556- nlohmann::json json;
557- json[" model" ] = model;
558- json[" bytes_downloaded" ] = bytesDownloaded;
559- json[" bytes_total" ] = bytesTotal;
560- return json;
561- }
562-
563- void setModel (const std::string& model) {
564- std::lock_guard<std::mutex> lock (mutex);
565- this ->model = model;
566- }
567-
568- void startThread () {
569- running = true ;
570- thread = std::thread ([this ]() {
571- while (running) {
572- std::this_thread::sleep_for (std::chrono::milliseconds (updateIntervalMs));
573- if (!started || firstUpdate) continue ;
574- auto json = getJson ();
575- std::string jsonStr = json.dump ();
576- std::cout << jsonStr << std::endl;
577- }
578- });
579- }
580-
581- void stopThread () {
582- running = false ;
583- thread.join ();
584- }
585- };
586-
587- bool downloadModelsFromZoo (const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& format) {
630+ bool downloadModelsFromZoo (const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat) {
588631 logger::info (" Downloading models from zoo" );
589632 // Make sure 'path' exists
590633 if (!std::filesystem::exists (path)) throw std::runtime_error (" Path does not exist: " + path);
@@ -602,9 +645,6 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
602645 }
603646 }
604647
605- auto jsonDownloadProgressManager = std::make_unique<JsonDownloadProgressManager>(100 );
606- jsonDownloadProgressManager->startThread ();
607- auto callback = std::bind (&JsonDownloadProgressManager::update, jsonDownloadProgressManager.get (), std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5);
608648 // Download models from yaml files
609649 int numSuccess = 0 , numFail = 0 ;
610650 for (size_t i = 0 ; i < models.size (); ++i) {
@@ -613,21 +653,17 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
613653
614654 // Download model - ignore the returned model path here == we are only interested in downloading the model
615655 try {
616- jsonDownloadProgressManager->setModel (modelName);
617- jsonDownloadProgressManager->resume ();
618656 logger::info (" Downloading model [{} / {}]: {}" , i + 1 , models.size (), modelName);
619657 auto modelDescription = NNModelDescription::fromYamlFile (modelName, path);
620- getModelFromZoo (modelDescription, true , cacheDirectory, apiKey, callback );
658+ getModelFromZoo (modelDescription, true , cacheDirectory, apiKey, progressFormat );
621659 logger::info (" Downloaded model [{} / {}]: {}" , i + 1 , models.size (), modelName);
622660 numSuccess++;
623661 } catch (const std::exception& e) {
624662 logger::error (" Failed to download model [{} / {}]: {} in folder {}\n {}" , i + 1 , models.size (), modelName, path, e.what ());
625663 numFail++;
626664 }
627- jsonDownloadProgressManager->pause ();
628665 }
629666
630- jsonDownloadProgressManager->stopThread ();
631667 logger::info (" Downloaded {} models from folder {} | {} failed." , numSuccess, path, numFail);
632668 return numFail == 0 ;
633669}
@@ -658,18 +694,24 @@ std::string ZooManager::getGlobalMetadataFilePath() const {
658694
659695#else
660696
661- std::string getModelFromZoo (const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
697+ std::string getModelFromZoo (const NNModelDescription& modelDescription,
698+ bool useCached,
699+ const std::string& cacheDirectory,
700+ const std::string& apiKey,
701+ const std::string& progressFormat) {
662702 (void )modelDescription;
663703 (void )useCached;
664704 (void )cacheDirectory;
665705 (void )apiKey;
706+ (void )progressFormat;
666707 throw std::runtime_error (" getModelFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled." );
667708}
668709
669- bool downloadModelsFromZoo (const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
710+ bool downloadModelsFromZoo (const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat ) {
670711 (void )path;
671712 (void )cacheDirectory;
672713 (void )apiKey;
714+ (void )progressFormat;
673715 throw std::runtime_error (" downloadModelsFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled." );
674716}
675717
0 commit comments