Skip to content

Commit b59f9c7

Browse files
authored
Merge pull request #347 from NVIDIA/dump_supported_ops
Dump supported ops
2 parents 1c9dfe2 + 0a45bdf commit b59f9c7

File tree

10 files changed

+430
-0
lines changed

10 files changed

+430
-0
lines changed

core/conversion/converters/NodeConverterRegistry.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class NodeConverterRegistry {
4747
public:
4848
bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
4949
LOG_DEBUG("Registering converter for " << canonical_schema_string(*signature));
50+
registered_converter_schemas_.insert(c10::toString(*signature));
5051
auto name = signature->operator_name();
5152
auto iter = converter_lut_.find(name);
5253
if (iter != converter_lut_.end()) {
@@ -83,8 +84,16 @@ class NodeConverterRegistry {
8384
}
8485
}
8586

87+
std::vector<std::string> GetRegisteredConverterList() {
88+
std::vector<std::string> converter_list;
89+
std::copy(
90+
registered_converter_schemas_.begin(), registered_converter_schemas_.end(), std::back_inserter(converter_list));
91+
return converter_list;
92+
}
93+
8694
private:
8795
ConverterLUT converter_lut_;
96+
std::set<std::string> registered_converter_schemas_;
8897
};
8998

9099
NodeConverterRegistry& get_converter_registry() {
@@ -115,6 +124,10 @@ bool node_is_convertable(const torch::jit::Node* n) {
115124
return get_converter_registry().Convertable(n);
116125
}
117126

127+
std::vector<std::string> get_converter_list() {
128+
return get_converter_registry().GetRegisteredConverterList();
129+
}
130+
118131
RegisterNodeConversionPatterns&& RegisterNodeConversionPatterns::pattern(ConversionPattern p) && {
119132
register_node_converter(std::move(p));
120133
return std::move(*this);

core/conversion/converters/converters.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class RegisterNodeConversionPatterns {
3939

4040
bool node_is_convertable(const torch::jit::Node* n);
4141
OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature);
42+
std::vector<std::string> get_converter_list();
4243

4344
} // namespace converters
4445
} // namespace conversion
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#include "NvInfer.h"
2+
#include "NvInferRuntimeCommon.h"
3+
#include "core/conversion/converters/converters.h"
4+
#include "core/util/prelude.h"
5+
#include "plugins/interpolate_plugin.h"
6+
#include "torch/torch.h"
7+
8+
namespace trtorch {
9+
namespace core {
10+
namespace conversion {
11+
namespace converters {
12+
namespace impl {
13+
namespace {
14+
15+
/*
16+
* Helper functions
17+
*/
18+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
19+
void create_plugin(
20+
ConversionCtx* ctx,
21+
const torch::jit::Node* n,
22+
nvinfer1::ITensor* in,
23+
const char* name,
24+
std::vector<int64_t> in_shape,
25+
std::vector<int64_t> out_shape,
26+
std::vector<int64_t> out_size,
27+
std::string mode) {
28+
LOG_WARNING("Interpolation layer will be run through ATen, not TensorRT. Performance may be lower than expected");
29+
30+
auto creator = new plugins::InterpolatePluginCreator();
31+
auto plugin = creator->createPlugin(name, in_shape, out_shape, out_size, mode, false);
32+
33+
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
34+
TRTORCH_CHECK(resize_layer, "Unable to create interpolation plugin from node" << *n);
35+
36+
resize_layer->setName(util::node_info(n).c_str());
37+
38+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
39+
40+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
41+
}
42+
#endif
43+
44+
void resize_layer_size(
45+
ConversionCtx* ctx,
46+
const torch::jit::Node* n,
47+
nvinfer1::ITensor* in,
48+
std::vector<int64_t> out_shape,
49+
nvinfer1::ResizeMode mode,
50+
bool align_corners = false) {
51+
auto resize_layer = ctx->net->addResize(*in);
52+
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
53+
54+
resize_layer->setOutputDimensions(util::toDims(out_shape));
55+
resize_layer->setResizeMode(mode);
56+
resize_layer->setName(util::node_info(n).c_str());
57+
58+
// if interpolation mode is linear, align corners must have been set to true.
59+
// else, don't use align corners.
60+
if (mode == nvinfer1::ResizeMode::kLINEAR) {
61+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
62+
resize_layer->setAlignCorners(true);
63+
#else
64+
resize_layer->setAlignCorners(align_corners);
65+
#endif
66+
}
67+
68+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
69+
70+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
71+
}
72+
73+
/*
74+
* Interpolate Converter
75+
*/
76+
77+
auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
78+
{"aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> (Tensor))",
79+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
80+
auto in = args[0].ITensor();
81+
auto in_shape = util::toVec(in->getDimensions());
82+
bool align_corners = args[2].unwrapToBool();
83+
84+
// Case 1: user uses output size and not scales_h, scales_w
85+
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone()) {
86+
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
87+
88+
TRTORCH_ASSERT(
89+
out_size.size() == 2, "aten::upsample_bilinear2d input Tensor and output size dimension mismatch");
90+
91+
auto out_shape = in_shape;
92+
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
93+
94+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
95+
if (!align_corners) {
96+
// align_corners not supported in TensorRT, create plugin and
97+
// run layer through PyTorch
98+
create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear"));
99+
} else {
100+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
101+
}
102+
#else
103+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
104+
#endif
105+
} else {
106+
TRTORCH_THROW_ERROR(
107+
"Unable to convert node: " << util::node_info(n)
108+
<< "\nScale factor parameter for upsample_bilinear2d not supported yet.");
109+
}
110+
111+
return true;
112+
}})
113+
} // namespace
114+
} // namespace impl
115+
} // namespace converters
116+
} // namespace conversion
117+
} // namespace core
118+
} // namespace trtorch

core/conversion/evaluators/NodeEvaluatorRegistry.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class NodeEvaluatorRegistry {
3636
"Attempting to override already registered evaluator " << node_kind.toQualString()
3737
<< ", merge implementations instead");
3838
}
39+
for (auto const& e : eval_reg.options.supported_variants) {
40+
registered_evaluator_schemas_.insert(e);
41+
}
3942
evaluator_lut_[node_kind] = std::move(eval_reg);
4043
}
4144

@@ -76,6 +79,13 @@ class NodeEvaluatorRegistry {
7679
return evaluator;
7780
}
7881

82+
std::vector<std::string> GetRegisteredEvaluatorList() {
83+
std::vector<std::string> evaluator_list;
84+
std::copy(
85+
registered_evaluator_schemas_.begin(), registered_evaluator_schemas_.end(), std::back_inserter(evaluator_list));
86+
return evaluator_list;
87+
}
88+
7989
bool EvalAtConversionTime(const torch::jit::Node* n) {
8090
auto evaluator = FindEvaluator(n);
8191
if (evaluator == nullptr) {
@@ -87,6 +97,7 @@ class NodeEvaluatorRegistry {
8797

8898
private:
8999
EvaluatorLUT evaluator_lut_;
100+
std::set<std::string> registered_evaluator_schemas_;
90101
};
91102

92103
NodeEvaluatorRegistry& get_evaluator_registry() {
@@ -99,6 +110,10 @@ bool shouldEvalAtConversionTime(const torch::jit::Node* n) {
99110
return get_evaluator_registry().EvalAtConversionTime(n);
100111
}
101112

113+
std::vector<std::string> getEvaluatorList() {
114+
return get_evaluator_registry().GetRegisteredEvaluatorList();
115+
}
116+
102117
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
103118
auto evaluator = get_evaluator_registry().GetEvaluator(n);
104119
return evaluator(n, args);

core/conversion/evaluators/evaluators.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*,
3838
struct EvalOptions {
3939
std::set<c10::TypePtr> blacklisted_output_types;
4040
std::vector<c10::OperatorName> valid_schemas;
41+
std::vector<std::string> supported_variants;
4142
EvalOptions() = default;
4243
EvalOptions& blacklistOutputTypes(std::set<c10::TypePtr> types) {
4344
use_options = true;
4445
blacklisted_output_types = types;
4546
return *this;
4647
}
4748
EvalOptions& validSchemas(std::set<std::string> schemas) {
49+
std::copy(schemas.begin(), schemas.end(), std::back_inserter(supported_variants));
4850
use_options = true;
4951
for (auto s : schemas) {
5052
valid_schemas.push_back(torch::jit::parseSchema(s).operator_name());
@@ -72,6 +74,7 @@ struct EvalRegistration {
7274

7375
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
7476
bool shouldEvalAtConversionTime(const torch::jit::Node* n);
77+
std::vector<std::string> getEvaluatorList();
7578
void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
7679
void register_node_evaluator(EvalRegistration r);
7780

cpp/supportedops/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
cc_binary(
4+
name = "supportedops",
5+
srcs = [
6+
"main.cpp"
7+
],
8+
deps = [
9+
"//cpp/api:trtorch",
10+
"//core/conversion/converters"
11+
],
12+
)

cpp/supportedops/main.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include "core/conversion/converters/converters.h"
2+
#include "core/conversion/evaluators/evaluators.h"
3+
4+
#include <iostream>
5+
#include <sstream>
6+
#include <string>
7+
#include <vector>
8+
9+
int main(int argc, const char* argv[]) {
10+
std::vector<std::string> converters = trtorch::core::conversion::converters::get_converter_list();
11+
std::vector<std::string> evaluators = trtorch::core::conversion::evaluators::getEvaluatorList();
12+
13+
std::stringstream ss;
14+
15+
ss << R"TITLE(
16+
.. _supported_ops:
17+
18+
=================================
19+
Operators Supported
20+
=================================
21+
22+
)TITLE";
23+
24+
ss << R"SEC(
25+
Operators Currently Supported Through Converters
26+
-------------------------------------------------
27+
28+
)SEC";
29+
30+
for (auto c : converters) {
31+
ss << "- " << c << std::endl;
32+
}
33+
34+
ss << R"SEC(
35+
Operators Currently Supported Through Evaluators
36+
-------------------------------------------------
37+
38+
)SEC";
39+
40+
for (auto e : evaluators) {
41+
ss << "- " << e << std::endl;
42+
}
43+
44+
std::ofstream ofs;
45+
ofs.open(argv[1]);
46+
47+
ofs << ss.rdbuf();
48+
49+
return 0;
50+
}

docsrc/index.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ Contributor Documentation
9393
Indices
9494
----------------
9595

96+
* :ref:`supported_operators`
97+
98+
.. toctree::
99+
:caption: Indices
100+
:maxdepth: 1
101+
:hidden:
102+
103+
indices/supported_ops
104+
96105
* :ref:`genindex`
97106
* :ref:`search`
98107

docsrc/indices/.gitkeep

Whitespace-only changes.

0 commit comments

Comments
 (0)