Skip to content

Commit 8d5b123

Browse files
authored
Merge pull request #118 from abhi-iyer/master
Support for aten::add_.t operation
2 parents be69060 + c4c3ce1 commit 8d5b123

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,36 @@ auto aten_registrations TRTORCH_UNUSED = RegisterNodeEvaluators()
243243
"aten::add.int(int a, int b) -> (int)",
244244
"aten::add.float(float a, float b) -> (float)"
245245
})
246+
}).evaluator({
247+
c10::Symbol::fromQualString("aten::add_"),
248+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
249+
if (args.at(n->input(0)).IValue()->isList()) {
250+
auto a = args.at(n->input(0)).IValue()->toListRef();
251+
auto b = args.at(n->input(1)).IValue()->toListRef();
252+
253+
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
254+
c10::TypePtr elementType = lt->getElementType();
255+
256+
auto merged = c10::impl::GenericList(elementType);
257+
merged.reserve(a.size() + b.size());
258+
259+
for (auto each : a) {
260+
merged.emplace_back(each);
261+
}
262+
263+
for (auto each : b) {
264+
merged.emplace_back(each);
265+
}
266+
267+
return merged;
268+
} else {
269+
TRTORCH_THROW_ERROR("Unimplemented data type for aten::add_ evaluator: " << args.at(n->input(0)).IValue()->type()->str());
270+
return {};
271+
}
272+
},
273+
EvalOptions().validSchemas({
274+
"aten::add_.t(t[](a!) self, t[] b) -> (t[])"
275+
})
246276
}).evaluator({
247277
c10::Symbol::fromQualString("aten::mul"),
248278
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {

0 commit comments

Comments
 (0)