Skip to content

Commit b1a645a

Browse files
committed
poc agent progress print
1 parent 24a5a93 commit b1a645a

File tree

1 file changed

+116
-6
lines changed

1 file changed

+116
-6
lines changed

src/modelzoo/Zoo.cpp

Lines changed: 116 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
#include <cctype>
44
#include <filesystem>
55
#include <iostream>
6+
#include <memory>
7+
#include <thread>
8+
#include <atomic>
9+
#include <mutex>
610
#include <nlohmann/json.hpp>
11+
#include <nlohmann/json_fwd.hpp>
712

813
#include "utility/Environment.hpp"
914
#include "utility/Logging.hpp"
@@ -12,6 +17,7 @@
1217

1318
#ifdef DEPTHAI_ENABLE_CURL
1419
#include <cpr/api.h>
20+
#include <cpr/cprtypes.h>
1521
#include <cpr/parameters.h>
1622
#include <cpr/status_codes.h>
1723
#endif
@@ -96,8 +102,11 @@ class ZooManager {
96102

97103
/**
98104
* @brief Download model from model zoo
105+
*
106+
* @param responseJson: JSON with download links
107+
* @param cprCallback: Progress callback
99108
*/
100-
void downloadModel(const nlohmann::json& responseJson);
109+
void downloadModel(const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback);
101110

102111
/**
103112
* @brief Return path to model in cache
@@ -328,7 +337,7 @@ nlohmann::json ZooManager::fetchModelDownloadLinks() {
328337
return responseJson;
329338
}
330339

331-
void ZooManager::downloadModel(const nlohmann::json& responseJson) {
340+
void ZooManager::downloadModel(const nlohmann::json& responseJson, std::unique_ptr<cpr::ProgressCallback> cprCallback) {
332341
// Extract download links from response
333342
auto downloadLinks = responseJson["download_links"].get<std::vector<std::string>>();
334343
auto downloadHash = responseJson["hash"].get<std::string>();
@@ -346,7 +355,7 @@ void ZooManager::downloadModel(const nlohmann::json& responseJson) {
346355

347356
// Download all files and store them in cache folder
348357
for(const auto& downloadLink : downloadLinks) {
349-
cpr::Response downloadResponse = cpr::Get(cpr::Url(downloadLink));
358+
cpr::Response downloadResponse = cpr::Get(cpr::Url(downloadLink), *cprCallback);
350359
if(checkIsErrorModelDownload(downloadResponse)) {
351360
removeModelCacheFolder();
352361
throw std::runtime_error(generateErrorMessageModelDownload(downloadResponse));
@@ -395,7 +404,11 @@ std::string ZooManager::loadModelFromCache() const {
395404
return std::filesystem::absolute(folderFiles[0]).string();
396405
}
397406

398-
std::string getModelFromZoo(const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
407+
std::string getModelFromZoo(const NNModelDescription& modelDescription,
408+
bool useCached,
409+
const std::string& cacheDirectory,
410+
const std::string& apiKey,
411+
const std::function<void(cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, intptr_t)>& progressCallback) {
399412
// Check if model description is valid
400413
if(!modelDescription.check()) throw std::runtime_error("Invalid model description:\n" + modelDescription.toString());
401414

@@ -466,9 +479,21 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
466479
// Create cache folder
467480
zooManager.createCacheFolder();
468481

482+
// Create download progress callback
483+
std::unique_ptr<cpr::ProgressCallback> cprCallback;
484+
if(progressCallback) {
485+
cprCallback = std::make_unique<cpr::ProgressCallback>(
486+
[&](cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
487+
progressCallback(downloadTotal, downloadNow, uploadTotal, uploadNow, userdata);
488+
return true;
489+
});
490+
} else {
491+
cprCallback = std::make_unique<cpr::ProgressCallback>([](cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, cpr::cpr_off_t, intptr_t) { return true; });
492+
}
493+
469494
// Download model
470495
logger::info("Downloading model from model zoo");
471-
zooManager.downloadModel(responseJson);
496+
zooManager.downloadModel(responseJson, std::move(cprCallback));
472497

473498
// Store model as yaml in the cache folder
474499
std::string yamlPath = combinePaths(zooManager.getModelCacheFolderPath(cacheDirectory), "model.yaml");
@@ -479,6 +504,84 @@ std::string getModelFromZoo(const NNModelDescription& modelDescription, bool use
479504
return modelPath;
480505
}
481506

507+
std::string getModelFromZoo(const NNModelDescription& modelDescription, bool useCached, const std::string& cacheDirectory, const std::string& apiKey) {
508+
return getModelFromZoo(modelDescription, useCached, cacheDirectory, apiKey, nullptr);
509+
}
510+
511+
struct JsonDownloadProgressManager {
512+
JsonDownloadProgressManager(size_t updateIntervalMs = 1000) : updateIntervalMs(updateIntervalMs) {
513+
pause();
514+
}
515+
516+
size_t updateIntervalMs;
517+
518+
std::string model;
519+
size_t bytesDownloaded = 0;
520+
size_t bytesTotal = 0;
521+
std::mutex mutex;
522+
523+
std::thread thread;
524+
std::atomic<bool> running;
525+
std::atomic<bool> started;
526+
std::atomic<bool> firstUpdate;
527+
528+
bool pause() {
529+
std::lock_guard<std::mutex> lock(mutex);
530+
started = false;
531+
firstUpdate = true;
532+
return true;
533+
}
534+
535+
bool resume() {
536+
std::lock_guard<std::mutex> lock(mutex);
537+
started = true;
538+
firstUpdate = true;
539+
return true;
540+
}
541+
542+
void update(cpr::cpr_off_t downloadTotal, cpr::cpr_off_t downloadNow, cpr::cpr_off_t uploadTotal, cpr::cpr_off_t uploadNow, intptr_t userdata) {
543+
(void)uploadTotal;
544+
(void)uploadNow;
545+
(void)userdata;
546+
std::lock_guard<std::mutex> lock(mutex);
547+
this->bytesDownloaded = downloadNow;
548+
this->bytesTotal = downloadTotal;
549+
firstUpdate = false;
550+
}
551+
552+
nlohmann::json getJson() {
553+
std::lock_guard<std::mutex> lock(mutex);
554+
nlohmann::json json;
555+
json["model"] = model;
556+
json["bytes_downloaded"] = bytesDownloaded;
557+
json["bytes_total"] = bytesTotal;
558+
return json;
559+
}
560+
561+
void setModel(const std::string& model) {
562+
std::lock_guard<std::mutex> lock(mutex);
563+
this->model = model;
564+
}
565+
566+
void startThread() {
567+
running = true;
568+
thread = std::thread([this]() {
569+
while(running) {
570+
std::this_thread::sleep_for(std::chrono::milliseconds(updateIntervalMs));
571+
if(!started || firstUpdate) continue;
572+
auto json = getJson();
573+
std::string jsonStr = json.dump();
574+
std::cout << jsonStr << std::endl;
575+
}
576+
});
577+
}
578+
579+
void stopThread() {
580+
running = false;
581+
thread.join();
582+
}
583+
};
584+
482585
bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDirectory, const std::string& apiKey) {
483586
logger::info("Downloading models from zoo");
484587
// Make sure 'path' exists
@@ -497,6 +600,9 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
497600
}
498601
}
499602

603+
auto jsonDownloadProgressManager = std::make_unique<JsonDownloadProgressManager>(100);
604+
jsonDownloadProgressManager->startThread();
605+
auto callback = std::bind(&JsonDownloadProgressManager::update, jsonDownloadProgressManager.get(), std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5);
500606
// Download models from yaml files
501607
int numSuccess = 0, numFail = 0;
502608
for(size_t i = 0; i < models.size(); ++i) {
@@ -505,17 +611,21 @@ bool downloadModelsFromZoo(const std::string& path, const std::string& cacheDire
505611

506612
// Download model - ignore the returned model path here == we are only interested in downloading the model
507613
try {
614+
jsonDownloadProgressManager->setModel(modelName);
615+
jsonDownloadProgressManager->resume();
508616
logger::info("Downloading model [{} / {}]: {}", i + 1, models.size(), modelName);
509617
auto modelDescription = NNModelDescription::fromYamlFile(modelName, path);
510-
getModelFromZoo(modelDescription, true, cacheDirectory, apiKey);
618+
getModelFromZoo(modelDescription, true, cacheDirectory, apiKey, callback);
511619
logger::info("Downloaded model [{} / {}]: {}", i + 1, models.size(), modelName);
512620
numSuccess++;
513621
} catch(const std::exception& e) {
514622
logger::error("Failed to download model [{} / {}]: {} in folder {}\n{}", i + 1, models.size(), modelName, path, e.what());
515623
numFail++;
516624
}
625+
jsonDownloadProgressManager->pause();
517626
}
518627

628+
jsonDownloadProgressManager->stopThread();
519629
logger::info("Downloaded {} models from folder {} | {} failed.", numSuccess, path, numFail);
520630
return numFail == 0;
521631
}

0 commit comments

Comments
 (0)