33#include < cctype>
44#include < filesystem>
55#include < iostream>
6+ #include < memory>
7+ #include < mutex>
68#include < nlohmann/json.hpp>
9+ #include < nlohmann/json_fwd.hpp>
710
811#include " utility/Environment.hpp"
912#include " utility/Logging.hpp"
1215
1316#ifdef DEPTHAI_ENABLE_CURL
1417 #include < cpr/api.h>
18+ #include < cpr/cprtypes.h>
1519 #include < cpr/parameters.h>
1620 #include < cpr/status_codes.h>
1721#endif
@@ -96,8 +100,11 @@ class ZooManager {
96100
97101 /* *
98102 * @brief Download model from model zoo
103+ *
104+ * @param responseJson: JSON with download links
105+ * @param cprCallback: Progress callback
99106 */
100- void downloadModel (const nlohmann::json& responseJson);
107+ void downloadModel (const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback );
101108
102109 /* *
103110 * @brief Return path to model in cache
@@ -328,7 +335,7 @@ nlohmann::json ZooManager::fetchModelDownloadLinks() {
328335 return responseJson;
329336}
330337
331- void ZooManager::downloadModel (const nlohmann::json& responseJson) {
338+ void ZooManager::downloadModel (const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback ) {
332339 // Extract download links from response
333340 auto downloadLinks = responseJson[" download_links" ].get <std::vector<std::string>>();
334341 auto downloadHash = responseJson[" hash" ].get <std::string>();
@@ -346,7 +353,7 @@ void ZooManager::downloadModel(const nlohmann::json& responseJson) {
346353
347354 // Download all files and store them in cache folder
348355 for (const auto & downloadLink : downloadLinks) {
349- cpr::Response downloadResponse = cpr::Get (cpr::Url (downloadLink));
356+ cpr::Response downloadResponse = cpr::Get (cpr::Url (downloadLink), *cprCallback );
350357 if (checkIsErrorModelDownload (downloadResponse)) {
351358 removeModelCacheFolder ();
352359 throw std::runtime_error (generateErrorMessageModelDownload (downloadResponse));
@@ -395,7 +402,145 @@ std::string ZooManager::loadModelFromCache() const {
395402 return std::filesystem::absolute (folderFiles[0 ]).string ();
396403}
397404
398- std::string getModelFromZoo (const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
405+ class CprCallback {
406+ public:
407+ virtual ~CprCallback () = default ;
408+ CprCallback (const std::string& modelName) : modelName(modelName) {}
409+
410+ virtual void cprCallback (
411+ cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) = 0;
412+
413+ virtual std::unique_ptr<cpr::ProgressCallback> getCprProgressCallback () {
414+ return std::make_unique<cpr::ProgressCallback>(
415+ [this ](cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
416+ this ->cprCallback (downloadTotal, downloadNow, uploadTotal, uploadNow, userdata);
417+ return true ;
418+ });
419+ }
420+
421+ protected:
422+ std::string modelName;
423+ };
424+
425+ class JsonCprCallback : public CprCallback {
426+ constexpr static long long PRINT_INTERVAL_MS = 100 ;
427+
428+ public:
429+ JsonCprCallback (const std::string& modelName) : CprCallback(modelName) {
430+ startTime = std::chrono::steady_clock::time_point::min ();
431+ }
432+
433+ void print (long downloadTotal, long downloadNow, const std::string& modelName) {
434+ nlohmann::json json = {
435+ {" download_total" , downloadTotal},
436+ {" download_now" , downloadNow},
437+ {" model_name" , modelName},
438+ };
439+ std::cout << json.dump () << std::endl;
440+ }
441+
442+ void cprCallback (
443+ cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
444+ (void )uploadTotal;
445+ (void )uploadNow;
446+ (void )userdata;
447+
448+ bool firstCall = startTime == std::chrono::steady_clock::time_point::min ();
449+ if (firstCall || downloadTotal == 0 ) {
450+ startTime = std::chrono::steady_clock::now ();
451+ }
452+
453+ bool shouldPrint = std::chrono::steady_clock::now () - startTime > std::chrono::milliseconds (PRINT_INTERVAL_MS) || this ->downloadTotal != downloadTotal;
454+
455+ if (shouldPrint) {
456+ print (downloadTotal, downloadNow, modelName);
457+ startTime = std::chrono::steady_clock::now ();
458+ }
459+
460+ this ->downloadTotal = downloadTotal;
461+ this ->downloadNow = downloadNow;
462+ }
463+
464+ ~JsonCprCallback () override {
465+ if (downloadTotal != 0 ) {
466+ print (downloadTotal, downloadNow, modelName);
467+ }
468+ }
469+
470+ private:
471+ long downloadTotal = 0 ;
472+ long downloadNow = 0 ;
473+ std::chrono::steady_clock::time_point startTime;
474+ };
475+
476+ class PrettyCprCallback : public CprCallback {
477+ public:
478+ PrettyCprCallback (const std::string& modelName) : CprCallback(modelName), finalProgressPrinted(false ) {}
479+
480+ void cprCallback (
481+ cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
482+ (void )uploadTotal;
483+ (void )uploadNow;
484+ (void )userdata;
485+
486+ if (finalProgressPrinted) return ;
487+
488+ if (downloadTotal > 0 ) {
489+ float progress = static_cast <float >(downloadNow) / downloadTotal;
490+ int barWidth = 50 ;
491+ int pos = static_cast <int >(barWidth * progress);
492+
493+ std::cout << " \r Downloading " << modelName << " [" ;
494+ for (int i = 0 ; i < barWidth; ++i) {
495+ if (i < pos)
496+ std::cout << " =" ;
497+ else if (i == pos)
498+ std::cout << " >" ;
499+ else
500+ std::cout << " " ;
501+ }
502+ std::cout << " ] " << std::fixed << std::setprecision (3 ) << progress * 100 .0f << " % " << downloadNow / 1024 .0f / 1024 .0f << " /"
503+ << downloadTotal / 1024 .0f / 1024 .0f << " MB" ;
504+
505+ if (downloadNow == downloadTotal) {
506+ std::cout << std::endl;
507+ finalProgressPrinted = true ;
508+ } else {
509+ std::cout << " \r " ;
510+ std::cout.flush ();
511+ }
512+ }
513+ }
514+
515+ private:
516+ bool finalProgressPrinted;
517+ };
518+
519+ class NoneCprCallback : public CprCallback {
520+ public:
521+ NoneCprCallback (const std::string& modelName) : CprCallback(modelName) {}
522+
523+ void cprCallback (cpr::cpr_off_t , cpr::cpr_off_t , cpr::cpr_off_t , cpr::cpr_off_t , intptr_t ) override {
524+ // Do nothing
525+ }
526+ };
527+
528+ std::unique_ptr<CprCallback> getCprCallback (const std::string& format, const std::string& name) {
529+ if (format == " json" ) {
530+ return std::make_unique<JsonCprCallback>(name);
531+ } else if (format == " pretty" ) {
532+ return std::make_unique<PrettyCprCallback>(name);
533+ } else if (format == " none" ) {
534+ return std::make_unique<NoneCprCallback>(name);
535+ }
536+ throw std::runtime_error (" Invalid format: " + format);
537+ }
538+
539+ std::string getModelFromZoo (const NNModelDescription& modelDescription,
540+ bool useCached,
541+ const std::string& cacheDirectory,
542+ const std::string& apiKey,
543+ const std::string& progressFormat) {
399544 // Check if model description is valid
400545 if (!modelDescription.check ()) throw std::runtime_error (" Invalid model description:\n " + modelDescription.toString ());
401546
@@ -466,9 +611,12 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
466611 // Create cache folder
467612 zooManager.createCacheFolder ();
468613
614+ // Create download progress callback
615+ std::unique_ptr<CprCallback> cprCallback = getCprCallback (progressFormat, modelDescription.globalMetadataEntryName .size () > 0 ? modelDescription.globalMetadataEntryName : modelDescription.model );
616+
469617 // Download model
470618 logger::info (" Downloading model from model zoo" );
471- zooManager.downloadModel (responseJson);
619+ zooManager.downloadModel (responseJson, cprCallback-> getCprProgressCallback () );
472620
473621 // Store model as yaml in the cache folder
474622 std::string yamlPath = combinePaths (zooManager.getModelCacheFolderPath (cacheDirectory), " model.yaml" );
@@ -479,7 +627,7 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
479627 return modelPath;
480628}
481629
482- bool downloadModelsFromZoo (const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
630+ bool downloadModelsFromZoo (const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat ) {
483631 logger::info (" Downloading models from zoo" );
484632 // Make sure 'path' exists
485633 if (!std::filesystem::exists (path)) throw std::runtime_error (" Path does not exist: " + path);
@@ -507,7 +655,7 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
507655 try {
508656 logger::info (" Downloading model [{} / {}]: {}" , i + 1 , models.size (), modelName);
509657 auto modelDescription = NNModelDescription::fromYamlFile (modelName, path);
510- getModelFromZoo (modelDescription, true , cacheDirectory, apiKey);
658+ getModelFromZoo (modelDescription, true , cacheDirectory, apiKey, progressFormat );
511659 logger::info (" Downloaded model [{} / {}]: {}" , i + 1 , models.size (), modelName);
512660 numSuccess++;
513661 } catch (const std::exception& e) {
@@ -546,18 +694,24 @@ std::string ZooManager::getGlobalMetadataFilePath() const {
546694
547695#else
548696
549- std::string getModelFromZoo (const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
697+ std::string getModelFromZoo (const NNModelDescription& modelDescription,
698+ bool useCached,
699+ const std::string& cacheDirectory,
700+ const std::string& apiKey,
701+ const std::string& progressFormat) {
550702 (void )modelDescription;
551703 (void )useCached;
552704 (void )cacheDirectory;
553705 (void )apiKey;
706+ (void )progressFormat;
554707 throw std::runtime_error (" getModelFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled." );
555708}
556709
557- bool downloadModelsFromZoo (const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
710+ bool downloadModelsFromZoo (const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat ) {
558711 (void )path;
559712 (void )cacheDirectory;
560713 (void )apiKey;
714+ (void )progressFormat;
561715 throw std::runtime_error (" downloadModelsFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled." );
562716}
563717
0 commit comments