Skip to content

Commit 465ae31

Browse files
committed
Merge branch 'v3_cleanup' into lnotspotl/internal_xlink
2 parents e14f389 + 24231d4 commit 465ae31

File tree

13 files changed

+205
-36
lines changed

13 files changed

+205
-36
lines changed

.github/workflows/main.workflow.yml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,14 @@ on:
66
branches:
77
- main
88
- develop
9-
- rvc3_support
10-
- rvc3_develop
9+
- v3_develop
1110
tags:
1211
- 'v2*'
1312
pull_request:
1413
branches:
1514
- main
1615
- develop
17-
- rvc3_support
18-
- rvc3_develop
16+
- v3_develop
1917

2018
jobs:
2119

@@ -58,7 +56,6 @@ jobs:
5856
- name: Run clang-tidy
5957
run: cmake --build build --parallel 4
6058

61-
6259
build:
6360
runs-on: ${{ matrix.os }}
6461
strategy:

CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,7 @@ target_link_libraries(${TARGET_CORE_NAME}
698698
semver::semver
699699
magic_enum::magic_enum
700700
liblzma::liblzma
701+
lz4::lz4
701702
)
702703

703704
if(DEPTHAI_ENABLE_MP4V2)
@@ -1146,7 +1147,13 @@ include(CMakePackageConfigHelpers)
11461147

11471148
# Add additional targets to export group
11481149
if(NOT BUILD_SHARED_LIBS)
1149-
list(APPEND targets_to_export ${DEPTHAI_RESOURCE_LIBRARY_NAME} cmrc-base foxglove_websocket messages XLink)
1150+
list(APPEND targets_to_export ${DEPTHAI_RESOURCE_LIBRARY_NAME} cmrc-base XLink)
1151+
if(DEPTHAI_ENABLE_PROTOBUF)
1152+
list(APPEND targets_to_export messages)
1153+
endif()
1154+
if(DEPTHAI_ENABLE_REMOTE_CONNECTION)
1155+
list(APPEND targets_to_export foxglove_websocket)
1156+
endif()
11501157
endif()
11511158

11521159
# Export targets (capability to import current build directory)

bindings/python/src/modelzoo/ZooBindings.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@ void ZooBindings::bind(pybind11::module& m, void* pCallstack) {
3232
py::arg("useCached") = true,
3333
py::arg("cacheDirectory") = "",
3434
py::arg("apiKey") = "",
35+
py::arg("progressFormat") = "none",
3536
DOC(dai, getModelFromZoo));
3637

3738
m.def("downloadModelsFromZoo",
3839
downloadModelsFromZoo,
3940
py::arg("path"),
4041
py::arg("cacheDirectory") = "",
4142
py::arg("apiKey") = "",
43+
py::arg("progressFormat") = "none",
4244
DOC(dai, downloadModelsFromZoo));
4345

4446
// Bind NNModelDescription

include/depthai/modelzoo/Zoo.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,14 @@ struct NNModelDescription {
8686
* DEPTHAI_ZOO_CACHE_PATH environment variable and uses that if set, otherwise the default value is used (see getDefaultCachePath).
8787
* @param apiKey: API key for the model zoo, default is "". If apiKey is set to "", this function checks the DEPTHAI_ZOO_API_KEY environment variable and uses
8888
* that if set. Otherwise, no API key is used.
89+
* @param progressFormat: Format to use for progress output (possible values: pretty, json, none), default is "none"
8990
* @return std::string: Path to the model in cache
9091
*/
9192
std::string getModelFromZoo(const NNModelDescription& modelDescription,
9293
bool useCached = true,
9394
const std::string& cacheDirectory = "",
94-
const std::string& apiKey = "");
95+
const std::string& apiKey = "",
96+
const std::string& progressFormat = "none");
9597

9698
/**
9799
* @brief Helper function allowing one to download all models specified in yaml files in the given path and store them in the cache directory
@@ -101,9 +103,10 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription,
101103
* DEPTHAI_ZOO_CACHE_PATH environment variable and uses that if set, otherwise the default is used (see getDefaultCachePath).
102104
* @param apiKey: API key for the model zoo, default is "". If apiKey is set to "", this function checks the DEPTHAI_ZOO_API_KEY environment variable and uses
103105
* that if set. Otherwise, no API key is used.
106+
* @param progressFormat: Format to use for progress output (possible values: pretty, json, none), default is "none"
104107
* @return bool: True if all models were downloaded successfully, false otherwise
105108
*/
106-
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory = "", const std::string& apiKey = "");
109+
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory = "", const std::string& apiKey = "", const std::string& progressFormat = "none");
107110

108111
std::ostream& operator<<(std::ostream& os, const NNModelDescription& modelDescription);
109112

src/modelzoo/Zoo.cpp

Lines changed: 163 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
#include <cctype>
44
#include <filesystem>
55
#include <iostream>
6+
#include <memory>
7+
#include <mutex>
68
#include <nlohmann/json.hpp>
9+
#include <nlohmann/json_fwd.hpp>
710

811
#include "utility/Environment.hpp"
912
#include "utility/Logging.hpp"
@@ -12,6 +15,7 @@
1215

1316
#ifdef DEPTHAI_ENABLE_CURL
1417
#include <cpr/api.h>
18+
#include <cpr/cprtypes.h>
1519
#include <cpr/parameters.h>
1620
#include <cpr/status_codes.h>
1721
#endif
@@ -96,8 +100,11 @@ class ZooManager {
96100

97101
/**
98102
* @brief Download model from model zoo
103+
*
104+
* @param responseJson: JSON with download links
105+
* @param cprCallback: Progress callback
99106
*/
100-
void downloadModel(const nlohmann::json& responseJson);
107+
void downloadModel(const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback);
101108

102109
/**
103110
* @brief Return path to model in cache
@@ -328,7 +335,7 @@ nlohmann::json ZooManager::fetchModelDownloadLinks() {
328335
return responseJson;
329336
}
330337

331-
void ZooManager::downloadModel(const nlohmann::json& responseJson) {
338+
void ZooManager::downloadModel(const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback) {
332339
// Extract download links from response
333340
auto downloadLinks = responseJson["download_links"].get<std::vector<std::string>>();
334341
auto downloadHash = responseJson["hash"].get<std::string>();
@@ -346,7 +353,7 @@ void ZooManager::downloadModel(const nlohmann::json& responseJson) {
346353

347354
// Download all files and store them in cache folder
348355
for(const auto& downloadLink : downloadLinks) {
349-
cpr::Response downloadResponse = cpr::Get(cpr::Url(downloadLink));
356+
cpr::Response downloadResponse = cpr::Get(cpr::Url(downloadLink), *cprCallback);
350357
if(checkIsErrorModelDownload(downloadResponse)) {
351358
removeModelCacheFolder();
352359
throw std::runtime_error(generateErrorMessageModelDownload(downloadResponse));
@@ -395,7 +402,145 @@ std::string ZooManager::loadModelFromCache() const {
395402
return std::filesystem::absolute(folderFiles[0]).string();
396403
}
397404

398-
std::string getModelFromZoo(const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
405+
class CprCallback {
406+
public:
407+
virtual ~CprCallback() = default;
408+
CprCallback(const std::string& modelName) : modelName(modelName) {}
409+
410+
virtual void cprCallback(
411+
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) = 0;
412+
413+
virtual std::unique_ptr<cpr::ProgressCallback> getCprProgressCallback() {
414+
return std::make_unique<cpr::ProgressCallback>(
415+
[this](cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
416+
this->cprCallback(downloadTotal, downloadNow, uploadTotal, uploadNow, userdata);
417+
return true;
418+
});
419+
}
420+
421+
protected:
422+
std::string modelName;
423+
};
424+
425+
class JsonCprCallback : public CprCallback {
426+
constexpr static long long PRINT_INTERVAL_MS = 100;
427+
428+
public:
429+
JsonCprCallback(const std::string& modelName) : CprCallback(modelName) {
430+
startTime = std::chrono::steady_clock::time_point::min();
431+
}
432+
433+
void print(long downloadTotal, long downloadNow, const std::string& modelName) {
434+
nlohmann::json json = {
435+
{"download_total", downloadTotal},
436+
{"download_now", downloadNow},
437+
{"model_name", modelName},
438+
};
439+
std::cout << json.dump() << std::endl;
440+
}
441+
442+
void cprCallback(
443+
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
444+
(void)uploadTotal;
445+
(void)uploadNow;
446+
(void)userdata;
447+
448+
bool firstCall = startTime == std::chrono::steady_clock::time_point::min();
449+
if(firstCall || downloadTotal == 0) {
450+
startTime = std::chrono::steady_clock::now();
451+
}
452+
453+
bool shouldPrint = std::chrono::steady_clock::now() - startTime > std::chrono::milliseconds(PRINT_INTERVAL_MS) || this->downloadTotal != downloadTotal;
454+
455+
if(shouldPrint) {
456+
print(downloadTotal, downloadNow, modelName);
457+
startTime = std::chrono::steady_clock::now();
458+
}
459+
460+
this->downloadTotal = downloadTotal;
461+
this->downloadNow = downloadNow;
462+
}
463+
464+
~JsonCprCallback() override {
465+
if(downloadTotal != 0) {
466+
print(downloadTotal, downloadNow, modelName);
467+
}
468+
}
469+
470+
private:
471+
long downloadTotal = 0;
472+
long downloadNow = 0;
473+
std::chrono::steady_clock::time_point startTime;
474+
};
475+
476+
class PrettyCprCallback : public CprCallback {
477+
public:
478+
PrettyCprCallback(const std::string& modelName) : CprCallback(modelName), finalProgressPrinted(false) {}
479+
480+
void cprCallback(
481+
cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) override {
482+
(void)uploadTotal;
483+
(void)uploadNow;
484+
(void)userdata;
485+
486+
if(finalProgressPrinted) return;
487+
488+
if(downloadTotal > 0) {
489+
float progress = static_cast<float>(downloadNow) / downloadTotal;
490+
int barWidth = 50;
491+
int pos = static_cast<int>(barWidth * progress);
492+
493+
std::cout << "\rDownloading " << modelName << " [";
494+
for(int i = 0; i < barWidth; ++i) {
495+
if(i < pos)
496+
std::cout << "=";
497+
else if(i == pos)
498+
std::cout << ">";
499+
else
500+
std::cout << " ";
501+
}
502+
std::cout << "] " << std::fixed << std::setprecision(3) << progress * 100.0f << "% " << downloadNow / 1024.0f / 1024.0f << "/"
503+
<< downloadTotal / 1024.0f / 1024.0f << " MB";
504+
505+
if(downloadNow == downloadTotal) {
506+
std::cout << std::endl;
507+
finalProgressPrinted = true;
508+
} else {
509+
std::cout << "\r";
510+
std::cout.flush();
511+
}
512+
}
513+
}
514+
515+
private:
516+
bool finalProgressPrinted;
517+
};
518+
519+
class NoneCprCallback : public CprCallback {
520+
public:
521+
NoneCprCallback(const std::string& modelName) : CprCallback(modelName) {}
522+
523+
void cprCallback(cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, intptr_t) override {
524+
// Do nothing
525+
}
526+
};
527+
528+
std::unique_ptr<CprCallback> getCprCallback(const std::string& format, const std::string& name) {
529+
if(format == "json") {
530+
return std::make_unique<JsonCprCallback>(name);
531+
} else if(format == "pretty") {
532+
return std::make_unique<PrettyCprCallback>(name);
533+
} else if(format == "none") {
534+
return std::make_unique<NoneCprCallback>(name);
535+
}
536+
throw std::runtime_error("Invalid format: " + format);
537+
}
538+
539+
std::string getModelFromZoo(const NNModelDescription& modelDescription,
540+
bool useCached,
541+
const std::string& cacheDirectory,
542+
const std::string& apiKey,
543+
const std::string& progressFormat) {
399544
// Check if model description is valid
400545
if(!modelDescription.check()) throw std::runtime_error("Invalid model description:\n" + modelDescription.toString());
401546

@@ -466,9 +611,12 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
466611
// Create cache folder
467612
zooManager.createCacheFolder();
468613

614+
// Create download progress callback
615+
std::unique_ptr<CprCallback> cprCallback = getCprCallback(progressFormat, modelDescription.globalMetadataEntryName.size() > 0 ? modelDescription.globalMetadataEntryName : modelDescription.model);
616+
469617
// Download model
470618
logger::info("Downloading model from model zoo");
471-
zooManager.downloadModel(responseJson);
619+
zooManager.downloadModel(responseJson, cprCallback->getCprProgressCallback());
472620

473621
// Store model as yaml in the cache folder
474622
std::string yamlPath = combinePaths(zooManager.getModelCacheFolderPath(cacheDirectory), "model.yaml");
@@ -479,7 +627,7 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
479627
return modelPath;
480628
}
481629

482-
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
630+
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat) {
483631
logger::info("Downloading models from zoo");
484632
// Make sure 'path' exists
485633
if(!std::filesystem::exists(path)) throw std::runtime_error("Path does not exist: " + path);
@@ -507,7 +655,7 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
507655
try {
508656
logger::info("Downloading model [{} / {}]: {}", i + 1, models.size(), modelName);
509657
auto modelDescription = NNModelDescription::fromYamlFile(modelName, path);
510-
getModelFromZoo(modelDescription, true, cacheDirectory, apiKey);
658+
getModelFromZoo(modelDescription, true, cacheDirectory, apiKey, progressFormat);
511659
logger::info("Downloaded model [{} / {}]: {}", i + 1, models.size(), modelName);
512660
numSuccess++;
513661
} catch(const std::exception& e) {
@@ -546,18 +694,24 @@ std::string ZooManager::getGlobalMetadataFilePath() const {
546694

547695
#else
548696

549-
std::string getModelFromZoo(const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
697+
std::string getModelFromZoo(const NNModelDescription& modelDescription,
698+
bool useCached,
699+
const std::string& cacheDirectory,
700+
const std::string& apiKey,
701+
const std::string& progressFormat) {
550702
(void)modelDescription;
551703
(void)useCached;
552704
(void)cacheDirectory;
553705
(void)apiKey;
706+
(void)progressFormat;
554707
throw std::runtime_error("getModelFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled.");
555708
}
556709

557-
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
710+
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey, const std::string& progressFormat) {
558711
(void)path;
559712
(void)cacheDirectory;
560713
(void)apiKey;
714+
(void)progressFormat;
561715
throw std::runtime_error("downloadModelsFromZoo requires libcurl to be enabled. Please recompile DepthAI with libcurl enabled.");
562716
}
563717

0 commit comments

Comments
 (0)