File tree Expand file tree Collapse file tree 1 file changed +30
-0
lines changed
core/conversion/evaluators Expand file tree Collapse file tree 1 file changed +30
-0
lines changed Original file line number Diff line number Diff line change @@ -243,6 +243,36 @@ auto aten_registrations TRTORCH_UNUSED = RegisterNodeEvaluators()
243
243
" aten::add.int(int a, int b) -> (int)" ,
244
244
" aten::add.float(float a, float b) -> (float)"
245
245
})
246
+ }).evaluator({
247
+ c10::Symbol::fromQualString (" aten::add_" ),
248
+ [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
249
+ if (args.at (n->input (0 )).IValue ()->isList ()) {
250
+ auto a = args.at (n->input (0 )).IValue ()->toListRef ();
251
+ auto b = args.at (n->input (1 )).IValue ()->toListRef ();
252
+
253
+ c10::ListTypePtr lt = n->output ()->type ()->expect <c10::ListType>();
254
+ c10::TypePtr elementType = lt->getElementType ();
255
+
256
+ auto merged = c10::impl::GenericList (elementType);
257
+ merged.reserve (a.size () + b.size ());
258
+
259
+ for (auto each : a) {
260
+ merged.emplace_back (each);
261
+ }
262
+
263
+ for (auto each : b) {
264
+ merged.emplace_back (each);
265
+ }
266
+
267
+ return merged;
268
+ } else {
269
+ TRTORCH_THROW_ERROR (" Unimplemented data type for aten::add_ evaluator: " << args.at (n->input (0 )).IValue ()->type ()->str ());
270
+ return {};
271
+ }
272
+ },
273
+ EvalOptions ().validSchemas ({
274
+ " aten::add_.t(t[](a!) self, t[] b) -> (t[])"
275
+ })
246
276
}).evaluator({
247
277
c10::Symbol::fromQualString (" aten::mul" ),
248
278
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
You can’t perform that action at this time.
0 commit comments