Commit 298b5bb

refactor(//core/conversion/converters): Move to keying converters by c10::OperatorName, allowing support for different converters for overloaded operators

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 6ab9814 commit 298b5bb

2 files changed: +84 −16 lines
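What the change buys: the converter LUT was previously keyed by a torch::jit::Symbol built from the schema's qualified name alone (e.g. "aten::div"), so every overload of an operator collapsed onto a single entry and could only have one converter. c10::OperatorName, as returned by FunctionSchema::operator_name(), also carries the overload name ("Tensor", "Scalar", ...), so each overload can now register its own converter. A minimal standalone sketch of the difference (not part of this commit; the includes and header paths are assumptions and vary across libtorch versions):

#include <iostream>
#include <ATen/core/interned_strings.h>                      // c10::Symbol
#include <torch/csrc/jit/frontend/function_schema_parser.h>  // torch::jit::parseSchema

int main() {
  auto div_tensor = torch::jit::parseSchema("aten::div.Tensor(Tensor self, Tensor other) -> Tensor");
  auto div_scalar = torch::jit::parseSchema("aten::div.Scalar(Tensor self, Scalar other) -> Tensor");

  // Old key: a Symbol built from the qualified name only -- both overloads become "aten::div".
  auto old_a = c10::Symbol::fromQualString(div_tensor.name());
  auto old_b = c10::Symbol::fromQualString(div_scalar.name());
  std::cout << "Symbol keys equal: " << (old_a == old_b) << std::endl;        // 1: overloads collide

  // New key: OperatorName keeps the overload name, so the two schemas stay distinct.
  auto new_a = div_tensor.operator_name();
  auto new_b = div_scalar.operator_name();
  std::cout << "OperatorName keys equal: " << (new_a == new_b) << std::endl;  // 0: overloads distinguished
  return 0;
}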

core/conversion/converters/NodeConverterRegistry.cpp

Lines changed: 9 additions & 9 deletions

@@ -41,20 +41,20 @@ std::string canonical_schema_string(const torch::jit::FunctionSchema& schema) {
 }

 namespace {
-using ConverterLUT = std::unordered_map<torch::jit::Symbol, OpConverter>;
+using ConverterLUT = std::unordered_map<c10::OperatorName, OpConverter>;

 class NodeConverterRegistry {
 public:
     bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
         LOG_DEBUG("Registering Converter for " << canonical_schema_string(*signature));
-        auto sym = torch::jit::Symbol::fromQualString(signature->name());
-        converter_lut_[sym] = std::move(converter);
+        auto name = signature->operator_name();
+        converter_lut_[name] = std::move(converter);
         return true;
     }

     OpConverter GetConverter(const torch::jit::FunctionSchema* signature) {
-        auto sym = torch::jit::Symbol::fromQualString(signature->name());
-        auto iter = converter_lut_.find(sym);
+        auto name = signature->operator_name();
+        auto iter = converter_lut_.find(name);
         if (iter == converter_lut_.end()) {
             LOG_ERROR("Requested converter for " << signature->name() << ", but no such converter was found");
             // ASK: Is there a better way than returning a nullptr?

@@ -66,8 +66,8 @@ class NodeConverterRegistry {
     bool Convertable(const torch::jit::Node* n) {
         auto schema = n->maybeSchema();
         if (schema) {
-            auto sym = torch::jit::Symbol::fromQualString(schema->name());
-            auto iter = converter_lut_.find(sym);
+            auto name = schema->operator_name();
+            auto iter = converter_lut_.find(name);
             if (iter == converter_lut_.end()) {
                 return false;
             } else {

@@ -79,7 +79,7 @@ class NodeConverterRegistry {
             return false;
         }
     }
-
+
 private:
     ConverterLUT converter_lut_;
 };

@@ -111,7 +111,7 @@ OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature)
 bool node_is_convertable(const torch::jit::Node* n) {
     return get_converter_registry().Convertable(n);
 }
-
+
 RegisterNodeConversionPatterns&& RegisterNodeConversionPatterns::pattern(ConversionPattern p) && {
     register_node_converter(std::move(p));
     return std::move(*this);
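Since ConverterLUT is now a plain std::unordered_map keyed by c10::OperatorName, the change relies on libtorch providing equality and a std::hash specialization for OperatorName (otherwise the map above would not compile). A small isolated sketch of the keying scheme, using a string payload in place of OpConverter (illustrative only, not the registry's real API; header path is an assumption):

#include <iostream>
#include <string>
#include <unordered_map>
#include <torch/csrc/jit/frontend/function_schema_parser.h>  // torch::jit::parseSchema

int main() {
  // Stand-in for converter_lut_: OperatorName -> payload (string instead of OpConverter).
  std::unordered_map<c10::OperatorName, std::string> lut;

  auto register_stub = [&](const char* schema, const char* payload) {
    lut[torch::jit::parseSchema(schema).operator_name()] = payload;
  };

  // Two overloads of the same qualified name "aten::div" now get separate entries;
  // with the old Symbol key the second registration would have overwritten the first.
  register_stub("aten::div.Tensor(Tensor self, Tensor other) -> Tensor", "converter for the Tensor overload");
  register_stub("aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "converter for the Scalar overload");
  std::cout << lut.size() << " entries" << std::endl;  // 2

  // Lookup mirrors GetConverter()/Convertable(): the key comes from the node's full schema.
  auto key = torch::jit::parseSchema("aten::div.Tensor(Tensor self, Tensor other) -> Tensor").operator_name();
  std::cout << lut.at(key) << std::endl;               // "converter for the Tensor overload"
  return 0;
}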

core/conversion/converters/impl/element_wise.cpp

Lines changed: 75 additions & 7 deletions

@@ -68,7 +68,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
             TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n);

             add->setName(util::node_info(n).c_str());
-            auto out = associate_value_and_tensor(ctx, n->outputs()[0], add->getOutput(0));
+            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], add->getOutput(0));

             LOG_DEBUG("Output tensor shape: " << out->getDimensions());
             return true;

@@ -85,7 +85,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
             TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n);

             add->setName(util::node_info(n).c_str());
-            auto out = associate_value_and_tensor(ctx, n->outputs()[0], add->getOutput(0));
+            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], add->getOutput(0));

             LOG_DEBUG("Output tensor shape: " << out->getDimensions());
             return true;

@@ -102,13 +102,13 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
             TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n);

             sub->setName(util::node_info(n).c_str());
-            auto out = associate_value_and_tensor(ctx, n->outputs()[0], sub->getOutput(0));
+            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0));

             LOG_DEBUG("Output tensor shape: " << out->getDimensions());
             return true;
         }
     }).pattern({
-        "aten::div(Tensor self, Tensor other) -> Tensor",
+        "aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             // Should implement self / other
             auto self = args[0].ITensor();

@@ -118,13 +118,29 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
             TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);

             div->setName(util::node_info(n).c_str());
-            auto out = associate_value_and_tensor(ctx, n->outputs()[0], div->getOutput(0));
+            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));

             LOG_DEBUG("Output tensor shape: " << out->getDimensions());
             return true;
         }
     }).pattern({
-        "aten::mul(Tensor self, Tensor other) -> Tensor",
+        "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)",
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            // TODO: Remove with functionalization
+            auto self = args[0].ITensor();
+            auto other = args[1].ITensor();
+            auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other);
+
+            TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
+
+            div->setName(util::node_info(n).c_str());
+            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
+
+            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+            return true;
+        }
+    }).pattern({
+        "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             // Should implement self * other
             auto self = args[0].ITensor();

@@ -134,13 +150,65 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
             TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);

             mul->setName(util::node_info(n).c_str());
-            auto out = associate_value_and_tensor(ctx, n->outputs()[0], mul->getOutput(0));
+            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));
+
+            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+            return true;
+        }
+    }).pattern({
+        "aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)",
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            // TODO: Remove with functionalization
+            auto self = args[0].ITensor();
+            auto other = args[1].ITensor();
+            auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other);
+
+            TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);
+
+            mul->setName(util::node_info(n).c_str());
+            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));

             LOG_DEBUG("Output tensor shape: " << out->getDimensions());
             return true;
         }
     });

+// - func: div.Tensor(Tensor self, Tensor other) -> Tensor
+//   use_c10_dispatcher: full
+//   variants: function, method
+//   dispatch:
+//     CPU: div
+//     CUDA: div
+//     SparseCPU: div_sparse
+//     SparseCUDA: div_sparse
+//   supports_named_tensor: True
+
+// - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+//   variants: method
+//   dispatch:
+//     CPU: div_
+//     CUDA: div_
+//     SparseCPU: div_sparse_
+//     SparseCUDA: div_sparse_
+//   supports_named_tensor: True
+
+// - func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+//   dispatch:
+//     CPU: div_out
+//     CUDA: div_out
+//     SparseCPU: div_out_sparse_zerodim
+//     SparseCUDA: div_out_sparse_zerodim
+//   supports_named_tensor: True
+
+// # For C++ only, until we have conversion from C++ numbers to Tensor
+// - func: div.Scalar(Tensor self, Scalar other) -> Tensor
+//   use_c10_dispatcher: full
+//   variants: function, method
+//   supports_named_tensor: True
+
+// - func: div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+//   variants: method
+//   supports_named_tensor: True

 } // namespace
 } // namespace impl
