Skip to content

Commit ce478ac

Browse files
authored
Merge pull request #903 from NVIDIA/anuragd/adapt_new_threshold_criteria
fix: Considering rtol and atol in threshold comparison for floating p…
2 parents 6a12934 + 11ac11d commit ce478ac

File tree

2 files changed

+10
-13
lines changed

2 files changed

+10
-13
lines changed

tests/util/util.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,18 @@ namespace torch_tensorrt {
55
namespace tests {
66
namespace util {
77

8-
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold) {
9-
double maxValue = 0.0;
10-
for (auto& tensor : inputs) {
11-
maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
12-
}
13-
std::cout << "Max Difference: " << diff.abs().max().item<float>() << std::endl;
14-
std::cout << "Acceptable Threshold: " << threshold << std::endl;
15-
return diff.abs().max().item<float>() <= threshold * maxValue;
16-
}
17-
18-
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) {
8+
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol = 1e-8, float rtol = 1e-5) {
199
LOG_GRAPH(a << std::endl << b << std::endl);
2010
auto a_float = a.toType(at::kFloat);
2111
auto b_float = b.toType(at::kFloat);
22-
return checkRtol(a_float - b_float, {a_float, b_float}, threshold);
12+
13+
auto diff = a_float - b_float;
14+
auto result = diff.abs().max().item<float>() - (atol + rtol * b.abs().max().item<float>());
15+
16+
std::cout << "Max Difference: " << result << std::endl;
17+
std::cout << "Acceptable Threshold: " << threshold << std::endl;
18+
19+
return result <= threshold;
2320
}
2421

2522
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {

tests/util/util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace torch_tensorrt {
1111
namespace tests {
1212
namespace util {
1313

14-
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold);
14+
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol = 1e-8, float rtol = 1e-5);
1515

1616
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
1717

0 commit comments

Comments
 (0)