Conversation
There was a problem hiding this comment.
Cpp-linter Review
Used clang-format v18.1.8
Click here for the full clang-format patch
diff --git a/src/AMSlib/ml/surrogate.cpp b/src/AMSlib/ml/surrogate.cpp
index 512b816..8f3931f 100644
--- a/src/AMSlib/ml/surrogate.cpp
+++ b/src/AMSlib/ml/surrogate.cpp
@@ -17,2 +17 @@ using namespace ams;
-static std::string getDTypeAsString(
- torch::Dtype dtype)
+static std::string getDTypeAsString(torch::Dtype dtype)
@@ -32,2 +31 @@ static std::string getDTypeAsString(
-static std::string getAMSDTypeAsString(
- AMSDType dType)
+static std::string getAMSDTypeAsString(AMSDType dType)
@@ -42,2 +40 @@ static std::string getAMSDTypeAsString(
-static std::string getAMSResourceTypeAsString(
- AMSResourceType res)
+static std::string getAMSResourceTypeAsString(AMSResourceType res)
@@ -53,5 +50,2 @@ static std::string getAMSResourceTypeAsString(
-SurrogateModel::SurrogateModel(
- std::string& model_path,
- bool isDeltaUQ)
- : _model_path(model_path),
- _is_DeltaUQ(isDeltaUQ)
+SurrogateModel::SurrogateModel(std::string& model_path, bool isDeltaUQ)
+ : _model_path(model_path), _is_DeltaUQ(isDeltaUQ)
@@ -60,2 +54 @@ SurrogateModel::SurrogateModel(
- std::experimental::filesystem::path Path(
- model_path);
+ std::experimental::filesystem::path Path(model_path);
@@ -64,2 +57 @@ SurrogateModel::SurrogateModel(
- if (!std::experimental::filesystem::exists(
- Path, ec)) {
+ if (!std::experimental::filesystem::exists(Path, ec)) {
@@ -75,2 +67 @@ SurrogateModel::SurrogateModel(
- printf("Error opening %s\n",
- model_path.c_str());
+ printf("Error opening %s\n", model_path.c_str());
@@ -79,2 +70 @@ SurrogateModel::SurrogateModel(
- auto method_ptr =
- module.find_method("get_ams_info");
+ auto method_ptr = module.find_method("get_ams_info");
@@ -88,2 +78 @@ SurrogateModel::SurrogateModel(
- torch::IValue meta_ivalue =
- module.run_method("get_ams_info");
+ torch::IValue meta_ivalue = module.run_method("get_ams_info");
@@ -94,2 +83 @@ SurrogateModel::SurrogateModel(
- std::string value =
- item.value().toStringRef();
+ std::string value = item.value().toStringRef();
@@ -97,2 +85 @@ SurrogateModel::SurrogateModel(
- std::tie(model_dtype, torch_dtype) =
- convertModelDataType(value);
+ std::tie(model_dtype, torch_dtype) = convertModelDataType(value);
@@ -100,2 +87 @@ SurrogateModel::SurrogateModel(
- std::tie(model_device, torch_device) =
- convertModelResourceType(value);
+ std::tie(model_device, torch_device) = convertModelResourceType(value);
@@ -105,6 +91,4 @@ SurrogateModel::SurrogateModel(
- CFATAL(
- SurrogateModel,
- model_dtype == ams::AMS_UNKNOWN_TYPE ||
- model_device ==
- ams::AMSResourceType::AMS_UNKNOWN,
- "Model has unknown datatype or device");
+ CFATAL(SurrogateModel,
+ model_dtype == ams::AMS_UNKNOWN_TYPE ||
+ model_device == ams::AMSResourceType::AMS_UNKNOWN,
+ "Model has unknown datatype or device");
@@ -115,2 +99 @@ SurrogateModel::SurrogateModel(
- getAMSResourceTypeAsString(model_device)
- .c_str());
+ getAMSResourceTypeAsString(model_device).c_str());
@@ -120,2 +103 @@ SurrogateModel::SurrogateModel(
-std::tuple<ams::AMSDType, torch::Dtype>
-SurrogateModel::getModelDataType() const
+std::tuple<ams::AMSDType, torch::Dtype> SurrogateModel::getModelDataType() const
@@ -123,2 +105 @@ SurrogateModel::getModelDataType() const
- return std::make_tuple(model_dtype,
- torch_dtype);
+ return std::make_tuple(model_dtype, torch_dtype);
@@ -127,2 +108,2 @@ SurrogateModel::getModelDataType() const
-std::tuple<AMSResourceType, torch::DeviceType>
-SurrogateModel::getModelResourceType() const
+std::tuple<AMSResourceType, torch::DeviceType> SurrogateModel::
+ getModelResourceType() const
@@ -130,2 +111 @@ SurrogateModel::getModelResourceType() const
- return std::make_tuple(model_device,
- torch_device);
+ return std::make_tuple(model_device, torch_device);
@@ -134,3 +114,2 @@ SurrogateModel::getModelResourceType() const
-std::tuple<AMSResourceType, torch::DeviceType>
-SurrogateModel::convertModelResourceType(
- std::string& value)
+std::tuple<AMSResourceType, torch::DeviceType> SurrogateModel::
+ convertModelResourceType(std::string& value)
@@ -140,2 +119 @@ SurrogateModel::convertModelResourceType(
- return std::make_tuple(AMS_HOST,
- c10::DeviceType::CPU);
+ return std::make_tuple(AMS_HOST, c10::DeviceType::CPU);
@@ -143,2 +121 @@ SurrogateModel::convertModelResourceType(
- return std::make_tuple(AMS_DEVICE,
- c10::DeviceType::CUDA);
+ return std::make_tuple(AMS_DEVICE, c10::DeviceType::CUDA);
@@ -146,2 +123 @@ SurrogateModel::convertModelResourceType(
- return std::make_tuple(AMS_DEVICE,
- c10::DeviceType::CUDA);
+ return std::make_tuple(AMS_DEVICE, c10::DeviceType::CUDA);
@@ -154,4 +130,2 @@ SurrogateModel::convertModelResourceType(
- return std::make_tuple(
- AMS_UNKNOWN,
- c10::DeviceType::
- COMPILE_TIME_MAX_DEVICE_TYPES);
+ return std::make_tuple(AMS_UNKNOWN,
+ c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
@@ -160,2 +134 @@ SurrogateModel::convertModelResourceType(
-std::tuple<AMSDType, torch::Dtype>
-SurrogateModel::convertModelDataType(
+std::tuple<AMSDType, torch::Dtype> SurrogateModel::convertModelDataType(
@@ -167,2 +140 @@ SurrogateModel::convertModelDataType(
- return std::make_tuple(AMSDType::AMS_SINGLE,
- at ::kFloat);
+ return std::make_tuple(AMSDType::AMS_SINGLE, at ::kFloat);
@@ -170,2 +142 @@ SurrogateModel::convertModelDataType(
- return std::make_tuple(AMSDType::AMS_DOUBLE,
- at ::kDouble);
+ return std::make_tuple(AMSDType::AMS_DOUBLE, at ::kDouble);
@@ -174,3 +145 @@ SurrogateModel::convertModelDataType(
- FATAL(Surrogate,
- "unknown data type of model %s",
- type.c_str());
+ FATAL(Surrogate, "unknown data type of model %s", type.c_str());
@@ -182,2 +151 @@ SurrogateModel::convertModelDataType(
-std::tuple<torch::Tensor, torch::Tensor>
-SurrogateModel::_computeDetlaUQ(
+std::tuple<torch::Tensor, torch::Tensor> SurrogateModel::_computeDetlaUQ(
@@ -188,14 +156,11 @@ SurrogateModel::_computeDetlaUQ(
- at::Tensor output_mean_tensor =
- deltaUQTuple.toTuple()
- ->elements()[0]
- .toTensor()
- .set_requires_grad(false)
- .detach();
- at::Tensor output_stdev_tensor =
- deltaUQTuple.toTuple()
- ->elements()[1]
- .toTensor()
- .set_requires_grad(false)
- .detach();
- auto outer_dim =
- output_stdev_tensor.sizes().size() - 1;
+ at::Tensor output_mean_tensor = deltaUQTuple.toTuple()
+ ->elements()[0]
+ .toTensor()
+ .set_requires_grad(false)
+ .detach();
+ at::Tensor output_stdev_tensor = deltaUQTuple.toTuple()
+ ->elements()[1]
+ .toTensor()
+ .set_requires_grad(false)
+ .detach();
+ auto outer_dim = output_stdev_tensor.sizes().size() - 1;
@@ -204,2 +169 @@ SurrogateModel::_computeDetlaUQ(
- throw std::runtime_error(
- "Invalid DELTA_UQ policy");
+ throw std::runtime_error("Invalid DELTA_UQ policy");
@@ -208,2 +172 @@ SurrogateModel::_computeDetlaUQ(
- auto mean =
- output_stdev_tensor.mean(outer_dim);
+ auto mean = output_stdev_tensor.mean(outer_dim);
@@ -211,5 +174,2 @@ SurrogateModel::_computeDetlaUQ(
- return std::make_tuple(
- std::move(output_mean_tensor),
- std::move(predicate));
- } else if (policy ==
- AMSUQPolicy::AMS_DELTAUQ_MAX) {
+ return std::make_tuple(std::move(output_mean_tensor), std::move(predicate));
+ } else if (policy == AMSUQPolicy::AMS_DELTAUQ_MAX) {
@@ -219,3 +179 @@ SurrogateModel::_computeDetlaUQ(
- return std::make_tuple(
- std::move(output_mean_tensor),
- std::move(predicate));
+ return std::make_tuple(std::move(output_mean_tensor), std::move(predicate));
@@ -223,2 +181 @@ SurrogateModel::_computeDetlaUQ(
- throw std::runtime_error(
- "Invalid DELTA_UQ policy");
+ throw std::runtime_error("Invalid DELTA_UQ policy");
@@ -228,4 +185,4 @@ SurrogateModel::_computeDetlaUQ(
-std::tuple<torch::Tensor, torch::Tensor>
-SurrogateModel::_evaluate(torch::Tensor& inputs,
- AMSUQPolicy policy,
- float threshold)
+std::tuple<torch::Tensor, torch::Tensor> SurrogateModel::_evaluate(
+ torch::Tensor& inputs,
+ AMSUQPolicy policy,
+ float threshold)
@@ -237,5 +194,2 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- getDTypeAsString(
- torch::typeMetaToScalarType(
- inputs.dtype())) +
- " and model is " +
- getDTypeAsString(torch_dtype));
+ getDTypeAsString(torch::typeMetaToScalarType(inputs.dtype())) +
+ " and model is " + getDTypeAsString(torch_dtype));
@@ -246,3 +200 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- return _computeDetlaUQ(out,
- policy,
- threshold);
+ return _computeDetlaUQ(out, policy, threshold);
@@ -251,4 +203 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- at::Tensor output_tensor =
- out.toTensor()
- .set_requires_grad(false)
- .detach();
+ at::Tensor output_tensor = out.toTensor().set_requires_grad(false).detach();
@@ -257,8 +206,3 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- torch::zeros({output_tensor.sizes()[0], 1},
- torch::kBool);
- auto indices =
- torch::randperm(output_tensor.sizes()[0])
- .slice(0,
- 0,
- threshold *
- output_tensor.sizes()[0]);
+ torch::zeros({output_tensor.sizes()[0], 1}, torch::kBool);
+ auto indices = torch::randperm(output_tensor.sizes()[0])
+ .slice(0, 0, threshold * output_tensor.sizes()[0]);
@@ -268,2 +212 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- return std::make_tuple(std::move(output_tensor),
- std::move(predicate));
+ return std::make_tuple(std::move(output_tensor), std::move(predicate));
@@ -273,2 +216 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
-std::tuple<torch::Tensor, torch::Tensor>
-SurrogateModel::evaluate(
+std::tuple<torch::Tensor, torch::Tensor> SurrogateModel::evaluate(
@@ -285,5 +227,2 @@ SurrogateModel::evaluate(
- torch::DeviceType InputDevice =
- Inputs[0].device().type();
- torch::Dtype InputDType =
- torch::typeMetaToScalarType(
- Inputs[0].dtype());
+ torch::DeviceType InputDevice = Inputs[0].device().type();
+ torch::Dtype InputDType = torch::typeMetaToScalarType(Inputs[0].dtype());
@@ -300,2 +239 @@ SurrogateModel::evaluate(
- if (InputDType !=
- torch::typeMetaToScalarType(In.dtype())) {
+ if (InputDType != torch::typeMetaToScalarType(In.dtype())) {
@@ -308,2 +246 @@ SurrogateModel::evaluate(
- c10::SmallVector<torch::Tensor> ConvertedInputs(
- Inputs.begin(), Inputs.end());
+ c10::SmallVector<torch::Tensor> ConvertedInputs(Inputs.begin(), Inputs.end());
@@ -312,7 +249,3 @@ SurrogateModel::evaluate(
- if (InputDevice != torch_device ||
- InputDType != torch_dtype) {
- for (int i = 0; i < ConvertedInputs.size();
- i++) {
- ConvertedInputs[i] =
- ConvertedInputs[i].to(torch_device,
- torch_dtype);
+ if (InputDevice != torch_device || InputDType != torch_dtype) {
+ for (int i = 0; i < ConvertedInputs.size(); i++) {
+ ConvertedInputs[i] = ConvertedInputs[i].to(torch_device, torch_dtype);
@@ -322,2 +255 @@ SurrogateModel::evaluate(
- auto ITensor =
- torch::cat(ConvertedInputs, CAxis);
+ auto ITensor = torch::cat(ConvertedInputs, CAxis);
@@ -328,2 +260 @@ SurrogateModel::evaluate(
- auto [OTensor, Predicate] =
- _evaluate(ITensor, policy, threshold);
+ auto [OTensor, Predicate] = _evaluate(ITensor, policy, threshold);
@@ -334,2 +265 @@ SurrogateModel::evaluate(
- return std::make_tuple(std::move(OTensor),
- std::move(Predicate));
+ return std::make_tuple(std::move(OTensor), std::move(Predicate));
@@ -339,3 +269 @@ SurrogateModel::evaluate(
-std::unordered_map<
- std::string,
- std::shared_ptr<SurrogateModel>>
+std::unordered_map<std::string, std::shared_ptr<SurrogateModel>>
Have any feedback or feature suggestions? Share it here.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
There was a problem hiding this comment.
Cpp-linter Review
Used clang-format v18.1.8
Only 6 out of 7 clang-format concerns fit within this pull request's diff.
Click here for the full clang-format patch
diff --git a/src/AMSlib/ml/surrogate.cpp b/src/AMSlib/ml/surrogate.cpp
index ffd325c..828496c 100644
--- a/src/AMSlib/ml/surrogate.cpp
+++ b/src/AMSlib/ml/surrogate.cpp
@@ -59,2 +59 @@ SurrogateModel::SurrogateModel(
- std::experimental::filesystem::path Path(
- model_path);
+ std::experimental::filesystem::path Path(model_path);
@@ -74,2 +73 @@ SurrogateModel::SurrogateModel(
- printf("Error opening %s\n",
- model_path.c_str());
+ printf("Error opening %s\n", model_path.c_str());
@@ -126,2 +124,2 @@ SurrogateModel::getModelDataType() const
-std::tuple<AMSResourceType, torch::DeviceType>
-SurrogateModel::getModelResourceType() const
+std::tuple<AMSResourceType, torch::DeviceType> SurrogateModel::
+ getModelResourceType() const
@@ -129,2 +127 @@ SurrogateModel::getModelResourceType() const
- return std::make_tuple(model_device,
- torch_device);
+ return std::make_tuple(model_device, torch_device);
@@ -187,12 +184,10 @@ SurrogateModel::_computeDetlaUQ(
- at::Tensor output_mean_tensor =
- deltaUQTuple.toTuple()
- ->elements()[0]
- .toTensor()
- .set_requires_grad(false)
- .detach();
- at::Tensor output_stdev_tensor =
- deltaUQTuple.toTuple()
- ->elements()[1]
- .toTensor()
- .set_requires_grad(false)
- .detach();
+ at::Tensor output_mean_tensor = deltaUQTuple.toTuple()
+ ->elements()[0]
+ .toTensor()
+ .set_requires_grad(false)
+ .detach();
+ at::Tensor output_stdev_tensor = deltaUQTuple.toTuple()
+ ->elements()[1]
+ .toTensor()
+ .set_requires_grad(false)
+ .detach();
@@ -222,2 +217 @@ SurrogateModel::_computeDetlaUQ(
- throw std::runtime_error(
- "Invalid DELTA_UQ policy");
+ throw std::runtime_error("Invalid DELTA_UQ policy");
@@ -236,5 +230,2 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- getDTypeAsString(
- torch::typeMetaToScalarType(
- inputs.dtype())) +
- " and model is " +
- getDTypeAsString(torch_dtype));
+ getDTypeAsString(torch::typeMetaToScalarType(inputs.dtype())) +
+ " and model is " + getDTypeAsString(torch_dtype));
Have any feedback or feature suggestions? Share it here.
There was a problem hiding this comment.
Cpp-linter Review
Used clang-format v18.1.8
Only 5 out of 6 clang-format concerns fit within this pull request's diff.
Click here for the full clang-format patch
diff --git a/src/AMSlib/ml/surrogate.cpp b/src/AMSlib/ml/surrogate.cpp
index 4dc960e..036720e 100644
--- a/src/AMSlib/ml/surrogate.cpp
+++ b/src/AMSlib/ml/surrogate.cpp
@@ -58,2 +58 @@ SurrogateModel::SurrogateModel(
- std::experimental::filesystem::path Path(
- model_path);
+ std::experimental::filesystem::path Path(model_path);
@@ -73,2 +72 @@ SurrogateModel::SurrogateModel(
- printf("Error opening %s\n",
- model_path.c_str());
+ printf("Error opening %s\n", model_path.c_str());
@@ -125,2 +123,2 @@ SurrogateModel::getModelDataType() const
-std::tuple<AMSResourceType, torch::DeviceType>
-SurrogateModel::getModelResourceType() const
+std::tuple<AMSResourceType, torch::DeviceType> SurrogateModel::
+ getModelResourceType() const
@@ -128,2 +126 @@ SurrogateModel::getModelResourceType() const
- return std::make_tuple(model_device,
- torch_device);
+ return std::make_tuple(model_device, torch_device);
@@ -192,6 +189,5 @@ SurrogateModel::_computeDetlaUQ(
- at::Tensor output_stdev_tensor =
- deltaUQTuple.toTuple()
- ->elements()[1]
- .toTensor()
- .set_requires_grad(false)
- .detach();
+ at::Tensor output_stdev_tensor = deltaUQTuple.toTuple()
+ ->elements()[1]
+ .toTensor()
+ .set_requires_grad(false)
+ .detach();
@@ -235,5 +231,2 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- getDTypeAsString(
- torch::typeMetaToScalarType(
- inputs.dtype())) +
- " and model is " +
- getDTypeAsString(torch_dtype));
+ getDTypeAsString(torch::typeMetaToScalarType(inputs.dtype())) +
+ " and model is " + getDTypeAsString(torch_dtype));
Have any feedback or feature suggestions? Share it here.
There was a problem hiding this comment.
Cpp-linter Review
Used clang-format v18.1.8
Only 6 out of 7 clang-format concerns fit within this pull request's diff.
Click here for the full clang-format patch
diff --git a/src/AMSlib/ml/surrogate.cpp b/src/AMSlib/ml/surrogate.cpp
index a272c5a..c37beb0 100644
--- a/src/AMSlib/ml/surrogate.cpp
+++ b/src/AMSlib/ml/surrogate.cpp
@@ -61,2 +61 @@ SurrogateModel::SurrogateModel(
- if (!std::experimental::filesystem::exists(
- Path, ec)) {
+ if (!std::experimental::filesystem::exists(Path, ec)) {
@@ -72,2 +71 @@ SurrogateModel::SurrogateModel(
- printf("Error opening %s\n",
- model_path.c_str());
+ printf("Error opening %s\n", model_path.c_str());
@@ -127,2 +125 @@ SurrogateModel::getModelResourceType() const
- return std::make_tuple(model_device,
- torch_device);
+ return std::make_tuple(model_device, torch_device);
@@ -191,6 +188,5 @@ SurrogateModel::_computeDetlaUQ(
- at::Tensor output_stdev_tensor =
- deltaUQTuple.toTuple()
- ->elements()[1]
- .toTensor()
- .set_requires_grad(false)
- .detach();
+ at::Tensor output_stdev_tensor = deltaUQTuple.toTuple()
+ ->elements()[1]
+ .toTensor()
+ .set_requires_grad(false)
+ .detach();
@@ -225,4 +221,4 @@ SurrogateModel::_computeDetlaUQ(
-std::tuple<torch::Tensor, torch::Tensor>
-SurrogateModel::_evaluate(torch::Tensor& inputs,
- AMSUQPolicy policy,
- float threshold)
+std::tuple<torch::Tensor, torch::Tensor> SurrogateModel::_evaluate(
+ torch::Tensor& inputs,
+ AMSUQPolicy policy,
+ float threshold)
@@ -234,5 +230,2 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- getDTypeAsString(
- torch::typeMetaToScalarType(
- inputs.dtype())) +
- " and model is " +
- getDTypeAsString(torch_dtype));
+ getDTypeAsString(torch::typeMetaToScalarType(inputs.dtype())) +
+ " and model is " + getDTypeAsString(torch_dtype));
@@ -243,3 +236 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- return _computeDetlaUQ(out,
- policy,
- threshold);
+ return _computeDetlaUQ(out, policy, threshold);
Have any feedback or feature suggestions? Share it here.
There was a problem hiding this comment.
Cpp-linter Review
Used clang-format v18.1.8
Only 5 out of 7 clang-format concerns fit within this pull request's diff.
Click here for the full clang-format patch
diff --git a/src/AMSlib/ml/surrogate.cpp b/src/AMSlib/ml/surrogate.cpp
index 018de96..42db58b 100644
--- a/src/AMSlib/ml/surrogate.cpp
+++ b/src/AMSlib/ml/surrogate.cpp
@@ -58,2 +58 @@ SurrogateModel::SurrogateModel(std::string& model_path, bool isDeltaUQ)
- if (!std::experimental::filesystem::exists(
- Path, ec)) {
+ if (!std::experimental::filesystem::exists(Path, ec)) {
@@ -73,2 +72 @@ SurrogateModel::SurrogateModel(std::string& model_path, bool isDeltaUQ)
- auto method_ptr =
- module.find_method("get_ams_info");
+ auto method_ptr = module.find_method("get_ams_info");
@@ -128,3 +126,2 @@ SurrogateModel::getModelResourceType() const
-std::tuple<AMSResourceType, torch::DeviceType>
-SurrogateModel::convertModelResourceType(
- std::string& value)
+std::tuple<AMSResourceType, torch::DeviceType> SurrogateModel::
+ convertModelResourceType(std::string& value)
@@ -188,8 +185,6 @@ SurrogateModel::_computeDetlaUQ(
- at::Tensor output_stdev_tensor =
- deltaUQTuple.toTuple()
- ->elements()[1]
- .toTensor()
- .set_requires_grad(false)
- .detach();
- auto outer_dim =
- output_stdev_tensor.sizes().size() - 1;
+ at::Tensor output_stdev_tensor = deltaUQTuple.toTuple()
+ ->elements()[1]
+ .toTensor()
+ .set_requires_grad(false)
+ .detach();
+ auto outer_dim = output_stdev_tensor.sizes().size() - 1;
@@ -222,4 +217,4 @@ SurrogateModel::_computeDetlaUQ(
-std::tuple<torch::Tensor, torch::Tensor>
-SurrogateModel::_evaluate(torch::Tensor& inputs,
- AMSUQPolicy policy,
- float threshold)
+std::tuple<torch::Tensor, torch::Tensor> SurrogateModel::_evaluate(
+ torch::Tensor& inputs,
+ AMSUQPolicy policy,
+ float threshold)
@@ -231,5 +226,2 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- getDTypeAsString(
- torch::typeMetaToScalarType(
- inputs.dtype())) +
- " and model is " +
- getDTypeAsString(torch_dtype));
+ getDTypeAsString(torch::typeMetaToScalarType(inputs.dtype())) +
+ " and model is " + getDTypeAsString(torch_dtype));
@@ -240,3 +232 @@ SurrogateModel::_evaluate(torch::Tensor& inputs,
- return _computeDetlaUQ(out,
- policy,
- threshold);
+ return _computeDetlaUQ(out, policy, threshold);
Have any feedback or feature suggestions? Share it here.
Fixes #136