Skip to content

Commit 0745c5f

Browse files
committed
refactor(//cpp/bin/torchtrtc): Refactor the CLI to make it easier to
extend later Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 10957eb commit 0745c5f

File tree

9 files changed

+476
-302
lines changed

9 files changed

+476
-302
lines changed

cpp/bin/torchtrtc/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@ config_setting(
1010
cc_binary(
1111
name = "torchtrtc",
1212
srcs = [
13+
"accuracy.h",
14+
"accuracy.cpp",
15+
"fileio.h",
16+
"fileio.cpp",
17+
"luts.h",
1318
"main.cpp",
19+
"parser_util.h",
20+
"parser_util.cpp"
1421
],
1522
deps = [
1623
"//third_party/args",

cpp/bin/torchtrtc/accuracy.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include "accuracy.h"
2+
3+
#include "torch_tensorrt/logging.h"
4+
#include "torch_tensorrt/torch_tensorrt.h"
5+
6+
namespace torchtrtc {
7+
namespace accuracy {
8+
9+
bool check_rtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold) {
10+
double maxValue = 0.0;
11+
for (auto& tensor : inputs) {
12+
maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
13+
}
14+
torchtrt::logging::log(
15+
torchtrt::logging::Level::kDEBUG,
16+
std::string("Max Difference: ") + std::to_string(diff.abs().max().item<float>()));
17+
torchtrt::logging::log(
18+
torchtrt::logging::Level::kDEBUG, std::string("Acceptable Threshold: ") + std::to_string(threshold));
19+
return diff.abs().max().item<float>() <= threshold * maxValue;
20+
}
21+
22+
bool almost_equal(const at::Tensor& a, const at::Tensor& b, float threshold) {
23+
return check_rtol(a - b, {a, b}, threshold);
24+
}
25+
26+
} // namespace accuracy
27+
} // namespace torchtrtc

cpp/bin/torchtrtc/accuracy.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#pragma once
2+
3+
#include <stdlib.h>
4+
#include <iostream>
5+
#include <sstream>
6+
#include <vector>
7+
8+
#include "torch/script.h"
9+
#include "torch/torch.h"
10+
11+
namespace torchtrtc {
12+
namespace accuracy {
13+
14+
bool check_rtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold);
15+
bool almost_equal(const at::Tensor& a, const at::Tensor& b, float threshold);
16+
17+
} // namespace accuracy
18+
} // namespace torchtrtc

cpp/bin/torchtrtc/fileio.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include "fileio.h"
2+
3+
namespace torchtrtc {
4+
namespace fileio {
5+
6+
std::string read_buf(std::string const& path) {
7+
std::string buf;
8+
std::ifstream stream(path.c_str(), std::ios::binary);
9+
10+
if (stream) {
11+
stream >> std::noskipws;
12+
std::copy(std::istream_iterator<char>(stream), std::istream_iterator<char>(), std::back_inserter(buf));
13+
}
14+
15+
return buf;
16+
}
17+
18+
std::string get_cwd() {
19+
char buff[FILENAME_MAX]; // create string buffer to hold path
20+
if (getcwd(buff, FILENAME_MAX)) {
21+
std::string current_working_dir(buff);
22+
return current_working_dir;
23+
} else {
24+
torchtrt::logging::log(torchtrt::logging::Level::kERROR, "Unable to get current directory");
25+
exit(1);
26+
}
27+
}
28+
29+
std::string real_path(std::string path) {
30+
auto abs_path = path;
31+
char real_path_c[PATH_MAX];
32+
char* res = realpath(abs_path.c_str(), real_path_c);
33+
if (res) {
34+
return std::string(real_path_c);
35+
} else {
36+
torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Unable to find file ") + abs_path);
37+
exit(1);
38+
}
39+
}
40+
41+
std::string resolve_path(std::string path) {
42+
auto rpath = path;
43+
if (!(rpath.rfind("/", 0) == 0)) {
44+
rpath = get_cwd() + '/' + rpath;
45+
}
46+
return rpath;
47+
}
48+
49+
} // namespace fileio
50+
} // namespace torchtrtc

cpp/bin/torchtrtc/fileio.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#pragma once
2+
#include <stdlib.h>
3+
#include <iostream>
4+
#include <sstream>
5+
6+
#ifdef __linux__
7+
#include <linux/limits.h>
8+
#else
9+
#define PATH_MAX 260
10+
#endif
11+
12+
#if defined(_WIN32)
13+
#include <direct.h>
14+
#define getcwd _getcwd
15+
#define realpath(N, R) _fullpath((R), (N), PATH_MAX)
16+
#else
17+
#include <unistd.h>
18+
#endif
19+
20+
#include "NvInfer.h"
21+
#include "third_party/args/args.hpp"
22+
#include "torch/script.h"
23+
#include "torch/torch.h"
24+
25+
#include "torch_tensorrt/logging.h"
26+
#include "torch_tensorrt/ptq.h"
27+
#include "torch_tensorrt/torch_tensorrt.h"
28+
29+
namespace torchtrtc {
30+
namespace fileio {
31+
32+
std::string read_buf(std::string const& path);
33+
std::string get_cwd();
34+
std::string real_path(std::string path);
35+
std::string resolve_path(std::string path);
36+
37+
} // namespace fileio
38+
} // namespace torchtrtc

cpp/bin/torchtrtc/luts.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#pragma once
2+
3+
#include "NvInfer.h"
4+
#include "third_party/args/args.hpp"
5+
#include "torch/script.h"
6+
#include "torch/torch.h"
7+
8+
namespace torchtrtc {
9+
namespace luts {
10+
11+
at::ScalarType to_torch_dtype(torchtrt::DataType dtype) {
12+
switch (dtype) {
13+
case torchtrt::DataType::kHalf:
14+
return at::kHalf;
15+
case torchtrt::DataType::kChar:
16+
return at::kChar;
17+
case torchtrt::DataType::kInt:
18+
return at::kInt;
19+
case torchtrt::DataType::kBool:
20+
return at::kBool;
21+
case torchtrt::DataType::kFloat:
22+
default:
23+
return at::kFloat;
24+
}
25+
}
26+
27+
const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_map() {
28+
static const std::unordered_map<nvinfer1::DataType, at::ScalarType> trt_at_type_map = {
29+
{nvinfer1::DataType::kFLOAT, at::kFloat},
30+
{nvinfer1::DataType::kHALF, at::kHalf},
31+
{nvinfer1::DataType::kINT32, at::kInt},
32+
{nvinfer1::DataType::kINT8, at::kChar},
33+
{nvinfer1::DataType::kBOOL, at::kBool},
34+
};
35+
return trt_at_type_map;
36+
}
37+
38+
} // namespace luts
39+
} // namespace torchtrtc

0 commit comments

Comments
 (0)