diff --git a/HeterogeneousCore/SonicCore/BuildFile.xml b/HeterogeneousCore/SonicCore/BuildFile.xml
index b0d5e2a08b98f..9796c4363c612 100644
--- a/HeterogeneousCore/SonicCore/BuildFile.xml
+++ b/HeterogeneousCore/SonicCore/BuildFile.xml
@@ -2,7 +2,8 @@
+
-
+i
diff --git a/HeterogeneousCore/SonicCore/interface/RetryActionBase.h b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h
new file mode 100644
index 0000000000000..e3fc0bbb8af9a
--- /dev/null
+++ b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h
@@ -0,0 +1,35 @@
+#ifndef HeterogeneousCore_SonicCore_RetryActionBase
+#define HeterogeneousCore_SonicCore_RetryActionBase
+
+#include "FWCore/PluginManager/interface/PluginFactory.h"
+#include "FWCore/ParameterSet/interface/ParameterSet.h"
+#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h"
+#include
+#include
+
+// Abstract interface shared by all retry actions. Concrete actions implement
+// retry() and start(); eval() is the hook they use to reach the client.
+class RetryActionBase {
+public:
+  RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client);
+  virtual ~RetryActionBase() = default;
+
+  // Retry hook, implemented by each concrete action.
+  virtual void retry() = 0;
+  // Initialization hook, implemented by each concrete action.
+  virtual void start() = 0;
+
+  // Reports the current value of the shouldRetry_ flag.
+  bool shouldRetry() const { return shouldRetry_; }
+
+protected:
+  // Forwards to evaluate() on the client.
+  void eval();
+
+  SonicClientBase* client_;  // client pointer handed to the constructor
+  bool shouldRetry_;         // whether further retries should happen
+};
+
+// Plugin factory used to create retry actions by name. The creation signature
+// mirrors the RetryActionBase constructor: (configuration, owning client).
+using RetryActionFactory = edmplugin::PluginFactory<RetryActionBase*(const edm::ParameterSet&, SonicClientBase*)>;
+
+// Registers a concrete retry action with the factory under its own class name.
+// Defined inside the include guard so repeated inclusion never re-defines it.
+// The trailing semicolon is part of the macro: call sites omit it.
+#define DEFINE_RETRY_ACTION(type) DEFINE_EDM_PLUGIN(RetryActionFactory, type, #type);
+
+#endif
diff --git a/HeterogeneousCore/SonicCore/interface/SonicClientBase.h b/HeterogeneousCore/SonicCore/interface/SonicClientBase.h
index 47caaae8b2052..45a089701ed12 100644
--- a/HeterogeneousCore/SonicCore/interface/SonicClientBase.h
+++ b/HeterogeneousCore/SonicCore/interface/SonicClientBase.h
@@ -9,12 +9,15 @@
#include "HeterogeneousCore/SonicCore/interface/SonicDispatcherPseudoAsync.h"
#include
+#include
#include
#include
#include
enum class SonicMode { Sync = 1, Async = 2, PseudoAsync = 3 };
+class RetryActionBase;
+
class SonicClientBase {
public:
//constructor
@@ -54,14 +57,23 @@ class SonicClientBase {
SonicMode mode_;
bool verbose_;
std::unique_ptr dispatcher_;
- unsigned allowedTries_, tries_;
+ unsigned totalTries_;
std::optional holder_;
+  // RetryActionBase is only forward-declared in this header, so the deleter's
+  // call operator is declared here and defined out of line; this keeps
+  // unique_ptr from needing the complete type at this point.
+  struct RetryDeleter {
+    void operator()(RetryActionBase* ptr) const;
+  };
+
+  // Owning handle for one retry action, and the list of actions built from
+  // the "Retry" configuration entries.
+  using RetryActionPtr = std::unique_ptr<RetryActionBase, RetryDeleter>;
+  std::vector<RetryActionPtr> retryActions_;
+
+
//for logging/debugging
std::string debugName_, clientName_, fullDebugName_;
friend class SonicDispatcher;
friend class SonicDispatcherPseudoAsync;
+ friend class RetryActionBase;
};
#endif
diff --git a/HeterogeneousCore/SonicCore/plugins/BuildFile.xml b/HeterogeneousCore/SonicCore/plugins/BuildFile.xml
new file mode 100644
index 0000000000000..0ecf2187a0f82
--- /dev/null
+++ b/HeterogeneousCore/SonicCore/plugins/BuildFile.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
diff --git a/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc b/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc
new file mode 100644
index 0000000000000..9877013b93d5b
--- /dev/null
+++ b/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc
@@ -0,0 +1,30 @@
+#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h"
+#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h"
+
+// Retry action that re-issues the call through the same client. A call is
+// attempted at most `allowedTries` times in total (values 0 and 1 both mean
+// "no retry").
+class RetrySameServerAction : public RetryActionBase {
+public:
+  RetrySameServerAction(const edm::ParameterSet& pset, SonicClientBase* client)
+      : RetryActionBase(pset, client), allowedTries_(pset.getUntrackedParameter<unsigned>("allowedTries", 0)) {
+    // The base class arms shouldRetry_ unconditionally; make it reflect the budget.
+    updateShouldRetry();
+  }
+
+  // Resets the per-call state. shouldRetry_ is re-armed here as well; otherwise
+  // one exhausted budget would disable retries for every later call.
+  void start() override {
+    tries_ = 0;
+    updateShouldRetry();
+  }
+
+protected:
+  void retry() override;
+
+private:
+  // shouldRetry_ answers "would one more failure still be retried?", i.e.
+  // whether the failure count after the next failure stays below allowedTries_.
+  void updateShouldRetry() { shouldRetry_ = (tries_ + 1 < allowedTries_); }
+
+  unsigned allowedTries_;
+  unsigned tries_ = 0;
+};
+
+// Invoked by the client after a failed attempt while shouldRetry() is true.
+// Because the client returns right after calling retry(), this function must
+// either re-evaluate or have already reported (via shouldRetry()) that it
+// cannot; the flag is therefore kept one step ahead of the counter.
+void RetrySameServerAction::retry() {
+  ++tries_;
+  if (tries_ >= allowedTries_) {
+    // Defensive path: only reachable if the caller did not check shouldRetry().
+    shouldRetry_ = false;
+    edm::LogInfo("RetrySameServerAction") << "Max retry attempts reached. No further retries.";
+    return;
+  }
+  // Update the flag before re-evaluating: evaluate() may report the next
+  // failure (re-entering the client's finish()) before eval() returns.
+  updateShouldRetry();
+  eval();
+}
+
+DEFINE_RETRY_ACTION(RetrySameServerAction)
diff --git a/HeterogeneousCore/SonicCore/src/RetryActionBase.cc b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc
new file mode 100644
index 0000000000000..41b9a6186da2b
--- /dev/null
+++ b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc
@@ -0,0 +1,15 @@
+#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h"
+
+// Stores the client pointer and arms the retry flag; conf is not read here.
+RetryActionBase::RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client)
+    : client_{client}, shouldRetry_{true} {}
+
+// Calls evaluate() on the client; when no client is set, logs an error instead.
+void RetryActionBase::eval() {
+  if (!client_) {
+    edm::LogError("RetryActionBase") << "Client pointer is null, cannot evaluate.";
+    return;
+  }
+  client_->evaluate();
+}
+
+EDM_REGISTER_PLUGINFACTORY(RetryActionFactory, "RetryActionFactory");
diff --git a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc
index 745c51f17aaf3..9949d9d1f2ea2 100644
--- a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc
+++ b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc
@@ -1,18 +1,33 @@
#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h"
+#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/ParameterSet/interface/allowedValues.h"
+// Out-of-line definition of the retry-action deleter's call operator
+void SonicClientBase::RetryDeleter::operator()(RetryActionBase* ptr) const {
+  delete ptr;
+}
+
SonicClientBase::SonicClientBase(const edm::ParameterSet& params,
const std::string& debugName,
const std::string& clientName)
- : allowedTries_(params.getUntrackedParameter("allowedTries", 0)),
- debugName_(debugName),
- clientName_(clientName),
- fullDebugName_(debugName_) {
+ : debugName_(debugName), clientName_(clientName), fullDebugName_(debugName_) {
if (!clientName_.empty())
fullDebugName_ += ":" + clientName_;
+ const auto& retryPSetList = params.getParameter>("Retry");
std::string modeName(params.getParameter("mode"));
+
+ for (const auto& retryPSet : retryPSetList) {
+ const std::string& actionType = retryPSet.getParameter("retryType");
+
+ auto retryAction = RetryActionFactory::get()->create(actionType, retryPSet, this);
+ if (retryAction) {
+ //Convert to RetryActionPtr Type from raw pointer of retryAction
+ retryActions_.emplace_back(RetryActionPtr(retryAction.release()));
+ } else {
+ throw cms::Exception("Configuration") << "Unknown Retry type " << actionType << " for SonicClient: " << modeName;
+ }
+ }
+
if (modeName == "Sync")
setMode(SonicMode::Sync);
else if (modeName == "Async")
@@ -40,24 +55,30 @@ void SonicClientBase::start(edm::WaitingTaskWithArenaHolder holder) {
holder_ = std::move(holder);
}
-void SonicClientBase::start() { tries_ = 0; }
+// Zeroes the total try counter, then calls start() on every retry action.
+void SonicClientBase::start() {
+  totalTries_ = 0;
+  for (auto it = retryActions_.begin(); it != retryActions_.end(); ++it)
+    (*it)->start();
+}
void SonicClientBase::finish(bool success, std::exception_ptr eptr) {
//retries are only allowed if no exception was raised
if (!success and !eptr) {
- ++tries_;
- //if max retries has not been exceeded, call evaluate again
- if (tries_ < allowedTries_) {
- evaluate();
- //avoid calling doneWaiting() twice
- return;
- }
- //prepare an exception if exceeded
- else {
- edm::Exception ex(edm::errors::ExternalFailure);
- ex << "SonicCallFailed: call failed after max " << tries_ << " tries";
- eptr = make_exception_ptr(ex);
+ ++totalTries_;
+ for (const auto& action : retryActions_) {
+ if (action->shouldRetry()) {
+ action->retry(); // Call retry only if shouldRetry_ is true
+ return;
+ }
}
+ //prepare an exception if no more retries left
+ edm::LogInfo("SonicClientBase") << "SonicCallFailed: call failed, no retry actions available after " << totalTries_
+ << " tries.";
+ edm::Exception ex(edm::errors::ExternalFailure);
+ ex << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ << " tries.";
+ eptr = make_exception_ptr(ex);
}
if (holder_) {
holder_->doneWaiting(eptr);
@@ -74,7 +95,20 @@ void SonicClientBase::fillBasePSetDescription(edm::ParameterSetDescription& desc
//restrict allowed values
desc.ifValue(edm::ParameterDescription("mode", "PseudoAsync", true),
edm::allowedValues("Sync", "Async", "PseudoAsync"));
- if (allowRetry)
- desc.addUntracked("allowedTries", 0);
+ if (allowRetry) {
+ // Defines the structure of each entry in the VPSet
+ edm::ParameterSetDescription retryDesc;
+ retryDesc.add("retryType", "RetrySameServerAction");
+ retryDesc.addUntracked("allowedTries", 0);
+
+ // Define a default retry action
+ edm::ParameterSet defaultRetry;
+ defaultRetry.addParameter("retryType", "RetrySameServerAction");
+ defaultRetry.addUntrackedParameter("allowedTries", 0);
+
+ // Add the VPSet with the default retry action
+ desc.addVPSet("Retry", retryDesc, {defaultRetry});
+ }
+ desc.add("sonicClientBase", desc);
desc.addUntracked("verbose", false);
}
diff --git a/HeterogeneousCore/SonicCore/test/DummyClient.h b/HeterogeneousCore/SonicCore/test/DummyClient.h
index ccef888ad9f7d..6504843926c0a 100644
--- a/HeterogeneousCore/SonicCore/test/DummyClient.h
+++ b/HeterogeneousCore/SonicCore/test/DummyClient.h
@@ -36,7 +36,7 @@ class DummyClient : public SonicClient {
this->output_ = this->input_ * factor_;
//simulate a failure
- if (this->tries_ < fails_)
+ if (this->totalTries_ < fails_)
this->finish(false);
else
this->finish(true);
diff --git a/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py b/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py
index 2cc429138b85c..bf7b44cb01519 100644
--- a/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py
+++ b/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py
@@ -19,15 +19,19 @@
process.options.numberOfThreads = 2
process.options.numberOfStreams = 0
-
process.dummySync = _moduleClass(_moduleName,
input = cms.int32(1),
Client = cms.PSet(
mode = cms.string("Sync"),
factor = cms.int32(-1),
wait = cms.int32(10),
- allowedTries = cms.untracked.uint32(0),
fails = cms.uint32(0),
+ # list of retry actions; retryType names the plugin created for each entry,
+ # and allowedTries is read by that plugin from the same PSet
+ Retry = cms.VPSet(
+ cms.PSet(
+ retryType = cms.string('RetrySameServerAction'),
+ allowedTries = cms.untracked.uint32(0)
+ )
+ )
),
)
@@ -37,8 +41,14 @@
mode = cms.string("PseudoAsync"),
factor = cms.int32(2),
wait = cms.int32(10),
- allowedTries = cms.untracked.uint32(0),
fails = cms.uint32(0),
+ Retry = cms.VPSet(
+ cms.PSet(
+ retryType = cms.string('RetrySameServerAction'),
+ allowedTries = cms.untracked.uint32(0)
+ )
+ )
+
),
)
@@ -48,32 +58,53 @@
mode = cms.string("Async"),
factor = cms.int32(5),
wait = cms.int32(10),
- allowedTries = cms.untracked.uint32(0),
fails = cms.uint32(0),
+ Retry = cms.VPSet(
+ cms.PSet(
+ retryType = cms.string('RetrySameServerAction'),
+ allowedTries = cms.untracked.uint32(0)
+ )
+ )
),
)
+# Retry variants of the three modules above: fails = 1 makes the dummy client
+# report a failure on its first attempt, and allowedTries = 2 lets
+# RetrySameServerAction issue one more attempt.
process.dummySyncRetry = process.dummySync.clone(
Client = dict(
wait = 2,
- allowedTries = 2,
fails = 1,
+ Retry = cms.VPSet(
+ cms.PSet(
+ retryType = cms.string('RetrySameServerAction'),
+ allowedTries = cms.untracked.uint32(2)
+ )
+ )
+
)
)
process.dummyPseudoAsyncRetry = process.dummyPseudoAsync.clone(
Client = dict(
wait = 2,
- allowedTries = 2,
fails = 1,
+ Retry = cms.VPSet(
+ cms.PSet(
+ retryType = cms.string('RetrySameServerAction'),
+ allowedTries = cms.untracked.uint32(2)
+ )
+ )
)
)
process.dummyAsyncRetry = process.dummyAsync.clone(
Client = dict(
wait = 2,
- allowedTries = 2,
fails = 1,
+ Retry = cms.VPSet(
+ cms.PSet(
+ allowedTries = cms.untracked.uint32(2),
+ retryType = cms.string('RetrySameServerAction')
+ )
+ )
)
)
diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h
index df8f9b559427c..670e1a750bf0a 100644
--- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h
+++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h
@@ -65,6 +65,7 @@ class TritonClient : public SonicClient {
bool handle_exception(F&& call);
void reportServerSideStats(const ServerSideStats& stats) const;
+ void updateServer(std::string serverName);
ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status,
const inference::ModelStatistics& end_status) const;
diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc
index ddcdff83448d0..f232414c1e9e5 100644
--- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc
+++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc
@@ -61,7 +61,7 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d
useSharedMemory_(params.getUntrackedParameter("useSharedMemory")),
compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter("compression"))) {
options_.emplace_back(params.getParameter("modelName"));
- //get appropriate server for this model
+
edm::Service ts;
// We save the token to be able to notify the service in case of an exception in the evaluate method.
@@ -70,22 +70,9 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d
// create the context.
token_ = edm::ServiceRegistry::instance().presentToken();
- const auto& server =
- ts->serverInfo(options_[0].model_name_, params.getUntrackedParameter("preferredServer"));
- serverType_ = server.type;
- edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url;
- //enforce sync mode for fallback CPU server to avoid contention
- //todo: could enforce async mode otherwise (unless mode was specified by user?)
- if (serverType_ == TritonServerType::LocalCPU)
- setMode(SonicMode::Sync);
- isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU;
-
- //connect to the server
- TRITON_THROW_IF_ERROR(
- tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions),
- "TritonClient(): unable to create inference context",
- isLocal_);
-
+ //Connect to server
+ updateServer(params.getUntrackedParameter("preferredServer"));
+
//set options
options_[0].model_version_ = params.getParameter("modelVersion");
options_[0].client_timeout_ = params.getUntrackedParameter("timeout");
@@ -369,7 +356,7 @@ void TritonClient::getResults(const std::vector
//default case for sync and pseudo async
void TritonClient::evaluate() {
//undo previous signal from TritonException
- if (tries_ > 0) {
+ if (totalTries_ > 0) {
// If we are retrying then the evaluate method is called outside the frameworks TBB thread pool.
// So we need to setup the service token for the current thread to access the service registry.
edm::ServiceRegistry::Operate op(token_);
@@ -574,6 +561,26 @@ inference::ModelStatistics TritonClient::getServerSideStatus() const {
return inference::ModelStatistics{};
}
+// Asks the service for the server to use for this client's model (serverName
+// is passed through to that lookup), records the server type, and creates the
+// inference client for the returned URL.
+void TritonClient::updateServer(std::string serverName) {
+  edm::Service tritonService;
+  const auto& chosen = tritonService->serverInfo(options_[0].model_name_, serverName);
+
+  serverType_ = chosen.type;
+  edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << chosen.url;
+
+  //fallback CPU server: force sync mode to avoid contention
+  //todo: could enforce async mode otherwise (unless mode was specified by user?)
+  if (serverType_ == TritonServerType::LocalCPU) {
+    setMode(SonicMode::Sync);
+  }
+  isLocal_ = (serverType_ == TritonServerType::LocalGPU) or (serverType_ == TritonServerType::LocalCPU);
+
+  //create the client for the selected server
+  TRITON_THROW_IF_ERROR(
+      tc::InferenceServerGrpcClient::Create(&client_, chosen.url, false, chosen.useSsl, chosen.sslOptions),
+      "TritonClient(): unable to create inference context",
+      isLocal_);
+}
+
//for fillDescriptions
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
edm::ParameterSetDescription descClient;
diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py
index 9cede0e496706..f27d7711665af 100644
--- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py
+++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py
@@ -123,9 +123,14 @@
modelVersion = cms.string(""),
modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)),
verbose = cms.untracked.bool(options.verbose or options.verboseClient),
- allowedTries = cms.untracked.uint32(options.tries),
useSharedMemory = cms.untracked.bool(not options.noShm),
compression = cms.untracked.string(options.compression),
+ Retry = cms.VPSet(
+ cms.PSet(
+ retryType = cms.string('RetrySameServerAction'),
+ allowedTries = cms.untracked.uint32(options.tries)
+ )
+ )
)
)
)