Skip to content

Commit c4c3ce1

Browse files
committed
feat(aten::add_t): aten::add_.t evaluator that adds lists together
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent f216d3f commit c4c3ce1

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
#include "core/conversion/evaluators/evaluators.h"
1111
#include "core/conversion/evaluators/eval_macros.h"
1212

13-
// #include <csignal>
14-
1513
namespace trtorch {
1614
namespace core {
1715
namespace conversion {
@@ -247,21 +245,29 @@ auto aten_registrations TRTORCH_UNUSED = RegisterNodeEvaluators()
247245
})
248246
}).evaluator({
249247
c10::Symbol::fromQualString("aten::add_"),
250-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
251-
LOG_DEBUG("aten::add_ evaluator is found");
248+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
249+
if (args.at(n->input(0)).IValue()->isList()) {
250+
auto a = args.at(n->input(0)).IValue()->toListRef();
251+
auto b = args.at(n->input(1)).IValue()->toListRef();
252252

253-
// std::raise(SIGINT);
253+
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
254+
c10::TypePtr elementType = lt->getElementType();
254255

255-
if (args.at(n->input(0)).IValue()->isList()) {
256-
auto a = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
257-
auto b = args.at(n->input(1)).IValue()->to<c10::List<c10::IValue>>();
256+
auto merged = c10::impl::GenericList(elementType);
257+
merged.reserve(a.size() + b.size());
258258

259-
// incorrect syntax
260-
// for (auto each : b) {
261-
// a.push_back(each);
262-
// }
259+
for (auto each : a) {
260+
merged.emplace_back(each);
261+
}
263262

264-
return a;
263+
for (auto each : b) {
264+
merged.emplace_back(each);
265+
}
266+
267+
return merged;
268+
} else {
269+
TRTORCH_THROW_ERROR("Unimplemented data type for aten::add_ evaluator: " << args.at(n->input(0)).IValue()->type()->str());
270+
return {};
265271
}
266272
},
267273
EvalOptions().validSchemas({

0 commit comments

Comments
 (0)