Skip to content

Commit 789b991

Browse files
authored
Merge pull request #1305 from luxonis/v3_zoo04
Zoo features, agent team
2 parents 54acc85 + f33d071 commit 789b991

File tree

4 files changed

+185
-18
lines changed

4 files changed

+185
-18
lines changed

bindings/python/src/modelzoo/ZooBindings.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@ void ZooBindings::bind(pybind11::module& m, void* pCallstack) {
3232
py::arg("useCached") = true,
3333
py::arg("cacheDirectory") = "",
3434
py::arg("apiKey") = "",
35+
py::arg("progressFormat") = "none",
3536
DOC(dai, getModelFromZoo));
3637

3738
m.def("downloadModelsFromZoo",
3839
downloadModelsFromZoo,
3940
py::arg("path"),
4041
py::arg("cacheDirectory") = "",
4142
py::arg("apiKey") = "",
43+
py::arg("progressFormat") = "none",
4244
DOC(dai, downloadModelsFromZoo));
4345

4446
// Bind NNModelDescription

include/depthai/modelzoo/Zoo.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,14 @@ struct NNModelDescription {
8686
* DEPTHAI_ZOO_CACHE_PATH environment variable and uses that if set, otherwise the default value is used (see getDefaultCachePath).
8787
* @param apiKey: API key for the model zoo, default is "". If apiKey is set to "", this function checks the DEPTHAI_ZOO_API_KEY environment variable and uses
8888
* that if set. Otherwise, no API key is used.
89+
* @param progressFormat: Format to use for progress output (possible values: pretty, json, none), default is "none"
8990
* @return std::string: Path to the model in cache
9091
*/
9192
std::string getModelFromZoo(const NNModelDescription& modelDescription,
9293
bool useCached = true,
9394
const std::string& cacheDirectory = "",
94-
const std::string& apiKey = "");
95+
const std::string& apiKey = "",
96+
const std::string& progressFormat = "none");
9597

9698
/**
9799
* @brief Helper function allowing one to download all models specified in yaml files in the given path and store them in the cache directory
@@ -101,9 +103,10 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription,
101103
* DEPTHAI_ZOO_CACHE_PATH environment variable and uses that if set, otherwise the default is used (see getDefaultCachePath).
102104
* @param apiKey: API key for the model zoo, default is "". If apiKey is set to "", this function checks the DEPTHAI_ZOO_API_KEY environment variable and uses
103105
* that if set. Otherwise, no API key is used.
106+
* @param progressFormat: Format to use for progress output (possible values: pretty, json, none), default is "none"
104107
* @return bool: True if all models were downloaded successfully, false otherwise
105108
*/
106-
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory = "", const std::string& apiKey = "");
109+
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory = "", const std::string& apiKey = "", const std::string& progressFormat = "none");
107110

108111
std::ostream& operator<<(std::ostream& os, const NNModelDescription& modelDescription);
109112

src/modelzoo/Zoo.cpp

Lines changed: 163 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
#include <cctype>
44
#include <filesystem>
55
#include <iostream>
6+
#include <memory>
7+
#include <mutex>
68
#include <nlohmann/json.hpp>
9+
#include <nlohmann/json_fwd.hpp>
710

811
#include "utility/Environment.hpp"
912
#include "utility/Logging.hpp"
@@ -12,6 +15,7 @@
1215

1316
#ifdef DEPTHAI_ENABLE_CURL
1417
#include <cpr/api.h>
18+
#include <cpr/cprtypes.h>
1519
#include <cpr/parameters.h>
1620
#include <cpr/status_codes.h>
1721
#endif
@@ -96,8 +100,11 @@ class ZooManager {
96100

97101
/**
98102
* @brief Download model from model zoo
103+
*
104+
* @param responseJson: JSON with download links
105+
* @param cprCallback: Progress callback
99106
*/
100-
void downloadModel(const nlohmann::json& responseJson);
107+
void downloadModel(const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback);
101108

102109
/**
103110
* @brief Return path to model in cache
@@ -328,7 +335,7 @@ nlohmann::json ZooManager::fetchModelDownloadLinks() {
328335
return responseJson;
329336
}
330337

331-
void ZooManager::downloadModel(const nlohmann::json& responseJson) {
338+
void ZooManager::downloadModel(const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback) {
332339
// Extract download links from response
333340
auto downloadLinks = responseJson["download_links"].get<std::vector<std::string>>();
334341
auto downloadHash = responseJson["hash"].get<std::string>();
@@ -346,7 +353,7 @@ void ZooManager::downloadModel(const nlohmann::json& responseJson) {
346353

347354
// Download all files and store them in cache folder
348355
for(const auto& downloadLink : downloadLinks) {
349-
cpr::Response downloadResponse = cpr::Get(cpr::Url(downloadLink));
356+
cpr::Response downloadResponse = cpr::Get(cpr::Url(downloadLink), *cprCallback);
350357
if(checkIsErrorModelDownload(downloadResponse)) {
351358
removeModelCacheFolder();
352359
throw std::runtime_error(generateErrorMessageModelDownload(downloadResponse));
@@ -395,7 +402,145 @@ std::string ZooManager::loadModelFromCache() const {
395402
return std::filesystem::absolute(folderFiles[0]).string();
396403
}
397404

398-
std::string getModelFromZoo(const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
405+
class CprCallback {
406+
public:
407+
virtual ~CprCallback() = default;
408+
CprCallback(const std::string& modelName) : modelName(modelName) {}
409+
410+
virtual void cprCallback(
411+
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) = 0;
412+
413+
virtual std::unique_ptr<cpr::ProgressCallback> getCprProgressCallback() {
414+
return std::make_unique<cpr::ProgressCallback>(
415+
[this](cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
416+
this->cprCallback(downloadTotal, downloadNow, uploadTotal, uploadNow, userdata);
417+
return true;
418+
});
419+
}
420+
421+
protected:
422+
std::string modelName;
423+
};
424+
425+
class JsonCprCallback : public CprCallback {
426+
constexpr static long long PRINT_INTERVAL_MS = 100;
427+
428+
public:
429+
JsonCprCallback(const std::string& modelName) : CprCallback(modelName) {
430+
startTime = std::chrono::steady_clock::time_point::min();
431+
}
432+
433+
void print(long downloadTotal, long downloadNow, const std::string& modelName) {
434+
nlohmann::json json = {
435+
{"download_total", downloadTotal},
436+
{"download_now", downloadNow},
437+
{"model_name", modelName},
438+
};
439+
std::cout << json.dump() << std::endl;
440+
}
441+
442+
void cprCallback(
443+
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
444+
(void)uploadTotal;
445+
(void)uploadNow;
446+
(void)userdata;
447+
448+
bool firstCall = startTime == std::chrono::steady_clock::time_point::min();
449+
if(firstCall || downloadTotal == 0) {
450+
startTime = std::chrono::steady_clock::now();
451+
}
452+
453+
bool shouldPrint = std::chrono::steady_clock::now() - startTime > std::chrono::milliseconds(PRINT_INTERVAL_MS) || this->downloadTotal != downloadTotal;
454+
455+
if(shouldPrint) {
456+
print(downloadTotal, downloadNow, modelName);
457+
startTime = std::chrono::steady_clock::now();
458+
}
459+
460+
this->downloadTotal = downloadTotal;
461+
this->downloadNow = downloadNow;
462+
}
463+
464+
~JsonCprCallback() override {
465+
if(downloadTotal != 0) {
466+
print(downloadTotal, downloadNow, modelName);
467+
}
468+
}
469+
470+
private:
471+
long downloadTotal = 0;
472+
long downloadNow = 0;
473+
std::chrono::steady_clock::time_point startTime;
474+
};
475+
476+
class PrettyCprCallback : public CprCallback {
477+
public:
478+
PrettyCprCallback(const std::string& modelName) : CprCallback(modelName), finalProgressPrinted(false) {}
479+
480+
void cprCallback(
481+
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
482+
(void)uploadTotal;
483+
(void)uploadNow;
484+
(void)userdata;
485+
486+
if(finalProgressPrinted) return;
487+
488+
if(downloadTotal > 0) {
489+
float progress = static_cast<float>(downloadNow) / downloadTotal;
490+
int barWidth = 50;
491+
int pos = static_cast<int>(barWidth * progress);
492+
493+
std::cout << "\rDownloading " << modelName << " [";
494+
for(int i = 0; i < barWidth; ++i) {
495+
if(i < pos)
496+
std::cout << "=";
497+
else if(i == pos)
498+
std::cout << ">";
499+
else
500+
std::cout << " ";
501+
}
502+
std::cout << "] " << std::fixed << std::setprecision(3) << progress * 100.0f << "% " << downloadNow / 1024.0f / 1024.0f << "/"
503+
<< downloadTotal / 1024.0f / 1024.0f << " MB";
504+
505+
if(downloadNow == downloadTotal) {
506+
std::cout << std::endl;
507+
finalProgressPrinted = true;
508+
} else {
509+
std::cout << "\r";
510+
std::cout.flush();
511+
}
512+
}
513+
}
514+
515+
private:
516+
bool finalProgressPrinted;
517+
};
518+
519+
class NoneCprCallback : public CprCallback {
520+
public:
521+
NoneCprCallback(const std::string& modelName) : CprCallback(modelName) {}
522+
523+
void cprCallback(cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, intptr_t) override {
524+
// Do nothing
525+
}
526+
};
527+
528+
std::unique_ptr<CprCallback> getCprCallback(const std::string& format, const std::string& name) {
529+
if(format == "json") {
530+
return std::make_unique<JsonCprCallback>(name);
531+
} else if(format == "pretty") {
532+
return std::make_unique<PrettyCprCallback>(name);
533+
} else if(format == "none") {
534+
return std::make_unique<NoneCprCallback>(name);
535+
}
536+
throw std::runtime_error("Invalid format: " + format);
537+
}
538+
539+
std::string getModelFromZoo(const NNModelDescription& modelDescription,
540+
bool useCached,
541+
const std::string& cacheDirectory,
542+
const std::string& apiKey,
543+
const std::string& progressFormat) {
399544
// Check if model description is valid
400545
if(!modelDescription.check()) throw std::runtime_error("Invalid model description:\n" + modelDescription.toString());
401546

@@ -466,9 +611,12 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
466611
// Create cache folder
467612
zooManager.createCacheFolder();
468613

614+
// Create download progress callback
615+
std::unique_ptr<CprCallback> cprCallback = getCprCallback(progressFormat, modelDescription.globalMetadataEntryName.size() > 0 ? modelDescription.globalMetadataEntryName : modelDescription.model);
616+
469617
// Download model
470618
logger::info("Downloading model from model zoo");
471-
zooManager.downloadModel(responseJson);
619+
zooManager.downloadModel(responseJson, cprCallback->getCprProgressCallback());
472620

473621
// Store model as yaml in the cache folder
474622
std::string yamlPath = combinePaths(zooManager.getModelCacheFolderPath(cacheDirectory), "model.yaml");
@@ -479,7 +627,7 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
479627
return modelPath;
480628
}
481629

482-
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
630+
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat) {
483631
logger::info("Downloading models from zoo");
484632
// Make sure 'path' exists
485633
if(!std::filesystem::exists(path)) throw std::runtime_error("Path does not exist: " + path);
@@ -507,7 +655,7 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
507655
try {
508656
logger::info("Downloading model [{} / {}]: {}", i + 1, models.size(), modelName);
509657
auto modelDescription = NNModelDescription::fromYamlFile(modelName, path);
510-
getModelFromZoo(modelDescription, true, cacheDirectory, apiKey);
658+
getModelFromZoo(modelDescription, true, cacheDirectory, apiKey, progressFormat);
511659
logger::info("Downloaded model [{} / {}]: {}", i + 1, models.size(), modelName);
512660
numSuccess++;
513661
} catch(const std::exception& e) {
@@ -546,18 +694,24 @@ std::string ZooManager::getGlobalMetadataFilePath() const {
546694

547695
#else
548696

549-
std::string getModelFromZoo(const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
697+
std::string getModelFromZoo(const NNModelDescription& modelDescription,
698+
bool useCached,
699+
const std::string& cacheDirectory,
700+
const std::string& apiKey,
701+
const std::string& progressFormat) {
550702
(void)modelDescription;
551703
(void)useCached;
552704
(void)cacheDirectory;
553705
(void)apiKey;
706+
(void)progressFormat;
554707
throw std::runtime_error("getModelFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled.");
555708
}
556709

557-
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
710+
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat) {
558711
(void)path;
559712
(void)cacheDirectory;
560713
(void)apiKey;
714+
(void)progressFormat;
561715
throw std::runtime_error("downloadModelsFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled.");
562716
}
563717

src/modelzoo/zoo_helper.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ int main(int argc, char* argv[]) {
3333
const std::string DEFAULT_DOWNLOAD_ENDPOINT = dai::modelzoo::getDownloadEndpoint();
3434
program.add_argument("--download_endpoint").default_value(DEFAULT_DOWNLOAD_ENDPOINT).help("Endpoint to use for downloading models");
3535

36+
const std::string FORMAT_DEFAULT = "pretty";
37+
program.add_argument("--format").default_value(FORMAT_DEFAULT).help("Format to use for output (possible values: pretty, json)");
38+
3639
program.add_argument("--verbose").default_value(false).implicit_value(true).help("Verbose output");
3740

3841
// Parse arguments
@@ -50,9 +53,10 @@ int main(int argc, char* argv[]) {
5053
auto apiKey = program.get<std::string>("--api_key");
5154
auto healthEndpoint = program.get<std::string>("--health_endpoint");
5255
auto downloadEndpoint = program.get<std::string>("--download_endpoint");
56+
auto format = program.get<std::string>("--format");
5357

5458
bool verbose = program.get<bool>("--verbose");
55-
if(!dai::utility::isEnvSet("DEPTHAI_LEVEL") && verbose) {
59+
if(!dai::utility::isEnvSet("DEPTHAI_LEVEL") && verbose && format == "pretty") {
5660
dai::Logging::getInstance().logger.set_level(spdlog::level::info);
5761
}
5862

@@ -61,19 +65,23 @@ int main(int argc, char* argv[]) {
6165
dai::modelzoo::setDownloadEndpoint(downloadEndpoint);
6266

6367
// Print arguments
64-
std::cout << "Downloading models defined in yaml files in folder: " << yamlFolder << std::endl;
65-
std::cout << "Downloading models into cache folder: " << cacheFolder << std::endl;
66-
if(!apiKey.empty()) {
67-
std::cout << "Using API key: " << apiKey << std::endl;
68+
if(format == "pretty") {
69+
std::cout << "Downloading models defined in yaml files in folder: " << yamlFolder << std::endl;
70+
std::cout << "Downloading models into cache folder: " << cacheFolder << std::endl;
71+
if(!apiKey.empty()) {
72+
std::cout << "Using API key: " << apiKey << std::endl;
73+
}
6874
}
6975

7076
// Download models
71-
bool success = dai::downloadModelsFromZoo(yamlFolder, cacheFolder, apiKey);
77+
bool success = dai::downloadModelsFromZoo(yamlFolder, cacheFolder, apiKey, format);
7278
if(!success) {
7379
std::cerr << "Failed to download all models from " << yamlFolder << std::endl;
7480
return EXIT_FAILURE;
7581
}
7682

77-
std::cout << "Successfully downloaded all models from " << yamlFolder << std::endl;
83+
if(format == "pretty") {
84+
std::cout << "Successfully downloaded all models from " << yamlFolder << std::endl;
85+
}
7886
return EXIT_SUCCESS;
7987
}

0 commit comments

Comments
 (0)