Skip to content

Commit b8a0723

Browse files
committed
add json, none and pretty download prints
1 parent 5b79c09 commit b8a0723

File tree

4 files changed

+155
-109
lines changed

4 files changed

+155
-109
lines changed

bindings/python/src/modelzoo/ZooBindings.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@ void ZooBindings::bind(pybind11::module& m, void* pCallstack) {
3232
py::arg("useCached") = true,
3333
py::arg("cacheDirectory") = "",
3434
py::arg("apiKey") = "",
35+
py::arg("progressFormat") = "none",
3536
DOC(dai, getModelFromZoo));
3637

3738
m.def("downloadModelsFromZoo",
3839
downloadModelsFromZoo,
3940
py::arg("path"),
4041
py::arg("cacheDirectory") = "",
4142
py::arg("apiKey") = "",
43+
py::arg("progressFormat") = "none",
4244
DOC(dai, downloadModelsFromZoo));
4345

4446
// Bind NNModelDescription

include/depthai/modelzoo/Zoo.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,14 @@ struct NNModelDescription {
8686
* DEPTHAI_ZOO_CACHE_PATH environment variable and uses that if set, otherwise the default value is used (see getDefaultCachePath).
8787
* @param apiKey: API key for the model zoo, default is "". If apiKey is set to "", this function checks the DEPTHAI_ZOO_API_KEY environment variable and uses
8888
* that if set. Otherwise, no API key is used.
89+
* @param progressFormat: Format to use for progress output (possible values: pretty, json, none), default is "none"
8990
* @return std::string: Path to the model in cache
9091
*/
9192
std::string getModelFromZoo(const NNModelDescription& modelDescription,
9293
bool useCached = true,
9394
const std::string& cacheDirectory = "",
94-
const std::string& apiKey = "");
95+
const std::string& apiKey = "",
96+
const std::string& progressFormat = "none");
9597

9698
/**
9799
* @brief Helper function allowing one to download all models specified in yaml files in the given path and store them in the cache directory
@@ -101,10 +103,10 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription,
101103
* DEPTHAI_ZOO_CACHE_PATH environment variable and uses that if set, otherwise the default is used (see getDefaultCachePath).
102104
* @param apiKey: API key for the model zoo, default is "". If apiKey is set to "", this function checks the DEPTHAI_ZOO_API_KEY environment variable and uses
103105
* that if set. Otherwise, no API key is used.
104-
* @param format: Format to use for output (possible values: pretty, json), default is "pretty"
106+
* @param progressFormat: Format to use for progress output (possible values: pretty, json, none), default is "none"
105107
* @return bool: True if all models were downloaded successfully, false otherwise
106108
*/
107-
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory = "", const std::string& apiKey = "", const std::string& format = "pretty");
109+
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory = "", const std::string& apiKey = "", const std::string& progressFormat = "none");
108110

109111
std::ostream& operator<<(std::ostream& os, const NNModelDescription& modelDescription);
110112

src/modelzoo/Zoo.cpp

Lines changed: 147 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
#include <filesystem>
55
#include <iostream>
66
#include <memory>
7-
#include <thread>
8-
#include <atomic>
97
#include <mutex>
108
#include <nlohmann/json.hpp>
119
#include <nlohmann/json_fwd.hpp>
@@ -332,8 +330,6 @@ nlohmann::json ZooManager::fetchModelDownloadLinks() {
332330
throw std::runtime_error(generateErrorMessageHub(response));
333331
}
334332

335-
std::cout << "Response: " << response.text << std::endl;
336-
337333
// Extract download links from response
338334
nlohmann::json responseJson = nlohmann::json::parse(response.text);
339335
return responseJson;
@@ -406,11 +402,145 @@ std::string ZooManager::loadModelFromCache() const {
406402
return std::filesystem::absolute(folderFiles[0]).string();
407403
}
408404

405+
class CprCallback {
406+
public:
407+
virtual ~CprCallback() = default;
408+
CprCallback(const std::string& modelName) : modelName(modelName) {}
409+
410+
virtual void cprCallback(
411+
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) = 0;
412+
413+
virtual std::unique_ptr<cpr::ProgressCallback> getCprProgressCallback() {
414+
return std::make_unique<cpr::ProgressCallback>(
415+
[this](cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
416+
this->cprCallback(downloadTotal, downloadNow, uploadTotal, uploadNow, userdata);
417+
return true;
418+
});
419+
}
420+
421+
protected:
422+
std::string modelName;
423+
};
424+
425+
class JsonCprCallback : public CprCallback {
426+
constexpr static long long PRINT_INTERVAL_MS = 100;
427+
428+
public:
429+
JsonCprCallback(const std::string& modelName) : CprCallback(modelName) {
430+
startTime = std::chrono::steady_clock::time_point::min();
431+
}
432+
433+
void print(long downloadTotal, long downloadNow, const std::string& modelName) {
434+
nlohmann::json json = {
435+
{"download_total", downloadTotal},
436+
{"download_now", downloadNow},
437+
{"model_name", modelName},
438+
};
439+
std::cout << json.dump() << std::endl;
440+
}
441+
442+
void cprCallback(
443+
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
444+
(void)uploadTotal;
445+
(void)uploadNow;
446+
(void)userdata;
447+
448+
bool firstCall = startTime == std::chrono::steady_clock::time_point::min();
449+
if(firstCall || downloadTotal == 0) {
450+
startTime = std::chrono::steady_clock::now();
451+
}
452+
453+
bool shouldPrint = std::chrono::steady_clock::now() - startTime > std::chrono::milliseconds(PRINT_INTERVAL_MS) || this->downloadTotal != downloadTotal;
454+
455+
if(shouldPrint) {
456+
print(downloadTotal, downloadNow, modelName);
457+
startTime = std::chrono::steady_clock::now();
458+
}
459+
460+
this->downloadTotal = downloadTotal;
461+
this->downloadNow = downloadNow;
462+
}
463+
464+
~JsonCprCallback() override {
465+
if(downloadTotal != 0) {
466+
print(downloadTotal, downloadNow, modelName);
467+
}
468+
}
469+
470+
private:
471+
long downloadTotal = 0;
472+
long downloadNow = 0;
473+
std::chrono::steady_clock::time_point startTime;
474+
};
475+
476+
class PrettyCprCallback : public CprCallback {
477+
public:
478+
PrettyCprCallback(const std::string& modelName) : CprCallback(modelName), finalProgressPrinted(false) {}
479+
480+
void cprCallback(
481+
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
482+
(void)uploadTotal;
483+
(void)uploadNow;
484+
(void)userdata;
485+
486+
if(finalProgressPrinted) return;
487+
488+
if(downloadTotal > 0) {
489+
float progress = static_cast<float>(downloadNow) / downloadTotal;
490+
int barWidth = 50;
491+
int pos = static_cast<int>(barWidth * progress);
492+
493+
std::cout << "\rDownloading " << modelName << " [";
494+
for(int i = 0; i < barWidth; ++i) {
495+
if(i < pos)
496+
std::cout << "=";
497+
else if(i == pos)
498+
std::cout << ">";
499+
else
500+
std::cout << " ";
501+
}
502+
std::cout << "] " << std::fixed << std::setprecision(3) << progress * 100.0f << "% " << downloadNow / 1024.0f / 1024.0f << "/"
503+
<< downloadTotal / 1024.0f / 1024.0f << " MB";
504+
505+
if(downloadNow == downloadTotal) {
506+
std::cout << std::endl;
507+
finalProgressPrinted = true;
508+
} else {
509+
std::cout << "\r";
510+
std::cout.flush();
511+
}
512+
}
513+
}
514+
515+
private:
516+
bool finalProgressPrinted;
517+
};
518+
519+
class NoneCprCallback : public CprCallback {
520+
public:
521+
NoneCprCallback(const std::string& modelName) : CprCallback(modelName) {}
522+
523+
void cprCallback(cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, intptr_t) override {
524+
// Do nothing
525+
}
526+
};
527+
528+
std::unique_ptr<CprCallback> getCprCallback(const std::string& format, const std::string& name) {
529+
if(format == "json") {
530+
return std::make_unique<JsonCprCallback>(name);
531+
} else if(format == "pretty") {
532+
return std::make_unique<PrettyCprCallback>(name);
533+
} else if(format == "none") {
534+
return std::make_unique<NoneCprCallback>(name);
535+
}
536+
throw std::runtime_error("Invalid format: " + format);
537+
}
538+
409539
std::string getModelFromZoo(const NNModelDescription& modelDescription,
410540
bool useCached,
411541
const std::string& cacheDirectory,
412542
const std::string& apiKey,
413-
const std::function<void(cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, intptr_t)>& progressCallback) {
543+
const std::string& progressFormat) {
414544
// Check if model description is valid
415545
if(!modelDescription.check()) throw std::runtime_error("Invalid model description:\n" + modelDescription.toString());
416546

@@ -482,20 +612,11 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription,
482612
zooManager.createCacheFolder();
483613

484614
// Create download progress callback
485-
std::unique_ptr<cpr::ProgressCallback> cprCallback;
486-
if(progressCallback) {
487-
cprCallback = std::make_unique<cpr::ProgressCallback>(
488-
[&](cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
489-
progressCallback(downloadTotal, downloadNow, uploadTotal, uploadNow, userdata);
490-
return true;
491-
});
492-
} else {
493-
cprCallback = std::make_unique<cpr::ProgressCallback>([](cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, intptr_t) { return true; });
494-
}
615+
std::unique_ptr<CprCallback> cprCallback = getCprCallback(progressFormat, modelDescription.globalMetadataEntryName.size() > 0 ? modelDescription.globalMetadataEntryName : modelDescription.model);
495616

496617
// Download model
497618
logger::info("Downloading model from model zoo");
498-
zooManager.downloadModel(responseJson, std::move(cprCallback));
619+
zooManager.downloadModel(responseJson, cprCallback->getCprProgressCallback());
499620

500621
// Store model as yaml in the cache folder
501622
std::string yamlPath = combinePaths(zooManager.getModelCacheFolderPath(cacheDirectory), "model.yaml");
@@ -506,85 +627,7 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription,
506627
return modelPath;
507628
}
508629

509-
std::string getModelFromZoo(const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
510-
return getModelFromZoo(modelDescription, useCached, cacheDirectory, apiKey, nullptr);
511-
}
512-
513-
struct JsonDownloadProgressManager {
514-
JsonDownloadProgressManager(size_t updateIntervalMs = 1000) : updateIntervalMs(updateIntervalMs) {
515-
pause();
516-
}
517-
518-
size_t updateIntervalMs;
519-
520-
std::string model;
521-
size_t bytesDownloaded = 0;
522-
size_t bytesTotal = 0;
523-
std::mutex mutex;
524-
525-
std::thread thread;
526-
std::atomic<bool> running;
527-
std::atomic<bool> started;
528-
std::atomic<bool> firstUpdate;
529-
530-
bool pause() {
531-
std::lock_guard<std::mutex> lock(mutex);
532-
started = false;
533-
firstUpdate = true;
534-
return true;
535-
}
536-
537-
bool resume() {
538-
std::lock_guard<std::mutex> lock(mutex);
539-
started = true;
540-
firstUpdate = true;
541-
return true;
542-
}
543-
544-
void update(cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
545-
(void)uploadTotal;
546-
(void)uploadNow;
547-
(void)userdata;
548-
std::lock_guard<std::mutex> lock(mutex);
549-
this->bytesDownloaded = downloadNow;
550-
this->bytesTotal = downloadTotal;
551-
firstUpdate = false;
552-
}
553-
554-
nlohmann::json getJson() {
555-
std::lock_guard<std::mutex> lock(mutex);
556-
nlohmann::json json;
557-
json["model"] = model;
558-
json["bytes_downloaded"] = bytesDownloaded;
559-
json["bytes_total"] = bytesTotal;
560-
return json;
561-
}
562-
563-
void setModel(const std::string& model) {
564-
std::lock_guard<std::mutex> lock(mutex);
565-
this->model = model;
566-
}
567-
568-
void startThread() {
569-
running = true;
570-
thread = std::thread([this]() {
571-
while(running) {
572-
std::this_thread::sleep_for(std::chrono::milliseconds(updateIntervalMs));
573-
if(!started || firstUpdate) continue;
574-
auto json = getJson();
575-
std::string jsonStr = json.dump();
576-
std::cout << jsonStr << std::endl;
577-
}
578-
});
579-
}
580-
581-
void stopThread() {
582-
running = false;
583-
thread.join();
584-
}
585-
};
586-
587-
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& format) {
630+
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat) {
588631
logger::info("Downloading models from zoo");
589632
// Make sure 'path' exists
590633
if(!std::filesystem::exists(path)) throw std::runtime_error("Path does not exist: " + path);
@@ -602,9 +645,6 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
602645
}
603646
}
604647

605-
auto jsonDownloadProgressManager = std::make_unique<JsonDownloadProgressManager>(100);
606-
jsonDownloadProgressManager->startThread();
607-
auto callback = std::bind(&JsonDownloadProgressManager::update, jsonDownloadProgressManager.get(), std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5);
608648
// Download models from yaml files
609649
int numSuccess = 0, numFail = 0;
610650
for(size_t i = 0; i < models.size(); ++i) {
@@ -613,21 +653,17 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
613653

614654
// Download model - ignore the returned model path here == we are only interested in downloading the model
615655
try {
616-
jsonDownloadProgressManager->setModel(modelName);
617-
jsonDownloadProgressManager->resume();
618656
logger::info("Downloading model [{} / {}]: {}", i + 1, models.size(), modelName);
619657
auto modelDescription = NNModelDescription::fromYamlFile(modelName, path);
620-
getModelFromZoo(modelDescription, true, cacheDirectory, apiKey, callback);
658+
getModelFromZoo(modelDescription, true, cacheDirectory, apiKey, progressFormat);
621659
logger::info("Downloaded model [{} / {}]: {}", i + 1, models.size(), modelName);
622660
numSuccess++;
623661
} catch(const std::exception& e) {
624662
logger::error("Failed to download model [{} / {}]: {} in folder {}\n{}", i + 1, models.size(), modelName, path, e.what());
625663
numFail++;
626664
}
627-
jsonDownloadProgressManager->pause();
628665
}
629666

630-
jsonDownloadProgressManager->stopThread();
631667
logger::info("Downloaded {} models from folder {} | {} failed.", numSuccess, path, numFail);
632668
return numFail == 0;
633669
}
@@ -658,18 +694,24 @@ std::string ZooManager::getGlobalMetadataFilePath() const {
658694

659695
#else
660696

661-
std::string getModelFromZoo(const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
697+
std::string getModelFromZoo(const NNModelDescription& modelDescription,
698+
bool useCached,
699+
const std::string& cacheDirectory,
700+
const std::string& apiKey,
701+
const std::string& progressFormat) {
662702
(void)modelDescription;
663703
(void)useCached;
664704
(void)cacheDirectory;
665705
(void)apiKey;
706+
(void)progressFormat;
666707
throw std::runtime_error("getModelFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled.");
667708
}
668709

669-
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
710+
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat) {
670711
(void)path;
671712
(void)cacheDirectory;
672713
(void)apiKey;
714+
(void)progressFormat;
673715
throw std::runtime_error("downloadModelsFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled.");
674716
}
675717

src/modelzoo/zoo_helper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ int main(int argc, char* argv[]) {
7474
}
7575

7676
// Download models
77-
bool success = dai::downloadModelsFromZoo(yamlFolder, cacheFolder, apiKey);
77+
bool success = dai::downloadModelsFromZoo(yamlFolder, cacheFolder, apiKey, format);
7878
if(!success) {
7979
std::cerr << "Failed to download all models from " << yamlFolder << std::endl;
8080
return EXIT_FAILURE;

0 commit comments

Comments
 (0)