Skip to content

Commit a02bf4f

Browse files
committed
Finalize code changes, Fix all tests, clean up code
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 3a96b9c commit a02bf4f

File tree

17 files changed

+133
-123
lines changed

17 files changed

+133
-123
lines changed

core/conversion/conversionctx/BUILD

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@ cc_library(
1919
"@tensorrt//:nvinfer",
2020
"@tensorrt//:nvinferplugin",
2121
"//core/util:prelude",
22-
#"//core/plugins:trtorch_plugins",
2322
] + select({
2423
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
2524
"//conditions:default": ["@libtorch//:libtorch"],
2625
}),
27-
copts = ["-fpic"],
2826
)
2927

3028
load("@rules_pkg//:pkg.bzl", "pkg_tar")

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
#include "NvInferPluginUtils.h"
66

77
#include <utility>
8-
// #include "core/plugins/plugin_prelude.h"
98
#include "core/conversion/conversionctx/ConversionCtx.h"
10-
// #include "NvInferPlugin.h"
11-
// #include "NvInferPluginUtils.h"
129

1310
namespace trtorch {
1411
namespace core {
@@ -52,7 +49,6 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
5249
"[TRTorch Conversion Context] - ",
5350
util::logging::get_logger().get_reportable_severity(),
5451
util::logging::get_logger().get_is_colored_output_on()) {
55-
5652
// TODO: Support FP16 and FP32 from JIT information
5753
if (settings.device.gpu_id) {
5854
TRTORCH_CHECK(

core/conversion/converters/BUILD

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ cc_library(
2929
cc_library(
3030
name = "converters",
3131
hdrs = [
32-
"converters.h",
32+
"converters.h"
3333
],
3434
srcs = [
3535
"NodeConverterRegistry.cpp",
@@ -67,7 +67,6 @@ cc_library(
6767
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
6868
"//conditions:default": ["@libtorch//:libtorch"],
6969
}),
70-
#copts = ["-fPIC"],
7170
alwayslink = True,
7271
)
7372

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,6 @@ namespace converters {
1010
namespace impl {
1111
namespace {
1212

13-
static inline at::Tensor repeat_if_defined(const at::Tensor& t, int64_t repeat) {
14-
if (t.defined()) {
15-
return t.repeat(repeat);
16-
}
17-
return t;
18-
}
19-
2013
auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
2114
R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
2215
Tensor? mean, Tensor? var,

core/conversion/converters/impl/interpolate.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void create_plugin(
5353

5454
fc.nbFields = f.size();
5555
fc.fields = f.data();
56-
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "");
56+
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch");
5757
auto interpolate_plugin = creator->createPlugin(name, &fc);
5858

5959
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);

core/conversion/converters/impl/normalize.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
#include "NvInferRuntimeCommon.h"
33
#include "core/conversion/converters/converters.h"
44
#include "core/util/prelude.h"
5-
// #include "core/plugins/impl/normalize_plugin.h"
6-
// #include "core/plugins/plugin_prelude.h"
75
#include "torch/torch.h"
86

97
namespace trtorch {
@@ -42,7 +40,8 @@ void create_plugin(
4240
TRTORCH_THROW_ERROR("Axis of normalization layer cannot exceed input rank");
4341
}
4442
}
45-
auto creator = getPluginRegistry()->getPluginCreator("NormalizePlugintrtorch", "1", "");
43+
44+
auto creator = getPluginRegistry()->getPluginCreator("NormalizePlugin", "1", "trtorch");
4645
auto plugin = creator->createPlugin(name, &fc);
4746
auto normalize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
4847
TRTORCH_CHECK(normalize_layer, "Unable to create normalization plugin from node" << *n);

core/conversion/converters/impl/pooling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ auto pooling_registrations TRTORCH_UNUSED =
313313
auto out_shape = in_shape;
314314
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
315315

316-
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "");
316+
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch");
317317

318318
// Configure the plugin fields
319319
nvinfer1::PluginFieldCollection fc;

core/plugins/BUILD

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ config_setting(
1010
cc_library(
1111
name = "trtorch_plugins",
1212
hdrs = [
13-
"plugins.h",
1413
"impl/interpolate_plugin.h",
1514
"impl/normalize_plugin.h",
15+
"plugins.h",
16+
1617
],
1718
srcs = [
18-
"register_plugins.cpp",
1919
"impl/interpolate_plugin.cpp",
2020
"impl/normalize_plugin.cpp",
21+
"register_plugins.cpp",
2122
],
2223
deps = [
2324
"@tensorrt//:nvinfer",
@@ -28,6 +29,12 @@ cc_library(
2829
"//conditions:default": ["@libtorch//:libtorch"],
2930
}),
3031
alwayslink = True,
32+
copts = [
33+
"-pthread"
34+
],
35+
linkopts = [
36+
"-lpthread",
37+
]
3138
)
3239

3340
load("@rules_pkg//:pkg.bzl", "pkg_tar")

core/plugins/impl/interpolate_plugin.cpp

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
#include "core/plugins/plugins.h"
33
#include "core/util/prelude.h"
44

5-
using namespace nvinfer1;
6-
using namespace trtorch::core;
7-
// namespace trtorch {
8-
// namespace core {
9-
// namespace plugins {
5+
namespace trtorch {
6+
namespace core {
7+
namespace plugins {
8+
namespace impl {
109

1110
/*
1211
* InterpolatePlugin class implementations
@@ -119,7 +118,7 @@ const char* InterpolatePlugin::getPluginVersion() const {
119118
}
120119

121120
const char* InterpolatePlugin::getPluginNamespace() const {
122-
return "";
121+
return "trtorch";
123122
}
124123

125124
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
@@ -163,7 +162,7 @@ nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
163162

164163
nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs)
165164
const {
166-
return DataType::kFLOAT;
165+
return nvinfer1::DataType::kFLOAT;
167166
}
168167

169168
int InterpolatePlugin::initialize() {
@@ -216,14 +215,14 @@ bool InterpolatePlugin::supportsFormatCombination(
216215
TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to interpolate plugin");
217216
TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin");
218217

219-
const PluginTensorDesc& in = inOut[0];
218+
const nvinfer1::PluginTensorDesc& in = inOut[0];
220219

221220
if (pos == 0) {
222221
return (in.type == nvinfer1::DataType::kFLOAT) && (in.format == nvinfer1::TensorFormat::kLINEAR);
223222
}
224223

225224
// pos == 1, accessing information about output tensor
226-
const PluginTensorDesc& out = inOut[1];
225+
const nvinfer1::PluginTensorDesc& out = inOut[1];
227226

228227
return (in.type == out.type) && (in.format == out.format);
229228
}
@@ -233,7 +232,7 @@ void InterpolatePlugin::configurePlugin(
233232
int nbInputs,
234233
const nvinfer1::DynamicPluginTensorDesc* out,
235234
int nbOutputs) {
236-
dtype_ = DataType::kFLOAT;
235+
dtype_ = nvinfer1::DataType::kFLOAT;
237236
}
238237

239238
size_t InterpolatePlugin::getWorkspaceSize(
@@ -344,20 +343,20 @@ int InterpolatePlugin::enqueue(
344343
*/
345344

346345
InterpolatePluginCreator::InterpolatePluginCreator() {
347-
mPluginAttributes.emplace_back(PluginField("in_shape", nullptr, PluginFieldType::kINT32, 1));
348-
mPluginAttributes.emplace_back(PluginField("out_shape", nullptr, PluginFieldType::kINT32, 1));
349-
mPluginAttributes.emplace_back(PluginField("out_size", nullptr, PluginFieldType::kINT32, 1));
350-
mPluginAttributes.emplace_back(PluginField("scales", nullptr, PluginFieldType::kFLOAT32, 1));
351-
mPluginAttributes.emplace_back(PluginField("mode", nullptr, PluginFieldType::kCHAR, 1));
352-
mPluginAttributes.emplace_back(PluginField("align_corners", nullptr, PluginFieldType::kINT32, 1));
353-
mPluginAttributes.emplace_back(PluginField("use_scales", nullptr, PluginFieldType::kINT32, 1));
346+
mPluginAttributes.emplace_back(nvinfer1::PluginField("in_shape", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
347+
mPluginAttributes.emplace_back(nvinfer1::PluginField("out_shape", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
348+
mPluginAttributes.emplace_back(nvinfer1::PluginField("out_size", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
349+
mPluginAttributes.emplace_back(nvinfer1::PluginField("scales", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1));
350+
mPluginAttributes.emplace_back(nvinfer1::PluginField("mode", nullptr, nvinfer1::PluginFieldType::kCHAR, 1));
351+
mPluginAttributes.emplace_back(nvinfer1::PluginField("align_corners", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
352+
mPluginAttributes.emplace_back(nvinfer1::PluginField("use_scales", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
354353

355354
mFC.nbFields = mPluginAttributes.size();
356355
mFC.fields = mPluginAttributes.data();
357356
}
358357

359358
const char* InterpolatePluginCreator::getPluginNamespace() const {
360-
return "";
359+
return "trtorch";
361360
}
362361

363362
const char* InterpolatePluginCreator::getPluginName() const {
@@ -418,8 +417,7 @@ const nvinfer1::PluginFieldCollection* InterpolatePluginCreator::getFieldNames()
418417
return nullptr;
419418
}
420419

421-
REGISTER_TRTORCH_PLUGIN(InterpolatePluginCreator);
422-
423-
// } // namespace plugins
424-
// } // namespace core
425-
// } // namespace trtorch
420+
} // namespace impl
421+
} // namespace plugins
422+
} // namespace core
423+
} // namespace trtorch

core/plugins/impl/interpolate_plugin.h

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,15 @@
1313
#include "core/util/prelude.h"
1414
#include "torch/torch.h"
1515

16-
using namespace nvinfer1;
17-
//
18-
// namespace trtorch {
19-
// namespace core {
20-
// namespace plugins {
16+
namespace trtorch {
17+
namespace core {
18+
namespace plugins {
19+
namespace impl {
2120

2221
class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
2322
private:
2423
at::TensorOptions tensor_options_;
25-
DataType dtype_;
24+
nvinfer1::DataType dtype_;
2625

2726
std::vector<int64_t> in_shape_;
2827
std::vector<int64_t> out_shape_;
@@ -121,8 +120,8 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
121120
class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
122121
private:
123122
std::string name_;
124-
std::vector<PluginField> mPluginAttributes;
125-
PluginFieldCollection mFC;
123+
std::vector<nvinfer1::PluginField> mPluginAttributes;
124+
nvinfer1::PluginFieldCollection mFC;
126125

127126
public:
128127
InterpolatePluginCreator();
@@ -137,21 +136,12 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
137136

138137
nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override;
139138

140-
// InterpolatePlugin* createPlugin(
141-
// const char* name,
142-
// std::vector<int64_t> in_shape,
143-
// std::vector<int64_t> out_shape,
144-
// std::vector<int64_t> size,
145-
// std::vector<double> scales,
146-
// std::string mode,
147-
// bool align_corners,
148-
// bool use_scales);
149-
150139
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
151140

152141
const nvinfer1::PluginFieldCollection* getFieldNames() override;
153142
};
154143

155-
// } // namespace plugins
156-
// } // namespace core
157-
// } // namespace trtorch
144+
} // namespace impl
145+
} // namespace plugins
146+
} // namespace core
147+
} // namespace trtorch

0 commit comments

Comments
 (0)