Skip to content

Commit 410ee2a

Browse files
inocsin and narendasan
authored and committed
support gt/lt/eq/ge/le converters
Signed-off-by: inocsin <[email protected]>
1 parent 3e1cc88 commit 410ee2a

File tree

2 files changed

+266
-0
lines changed

2 files changed

+266
-0
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,187 @@ auto element_wise_registrations TRTORCH_UNUSED =
352352
pow->setName(util::node_info(n).c_str());
353353
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));
354354

355+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
356+
return true;
357+
}})
358+
.pattern({"aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TODO: Remove with functionalization
            // Elementwise greater-than between two tensor inputs.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = args[1].ITensorOrFreeze(ctx);
            auto layer_name = util::node_info(n);

            auto gt = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, lhs, rhs, layer_name);
            TRTORCH_CHECK(gt, "Unable to create greater layer from node: " << *n);
            gt->setName(layer_name.c_str());

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gt->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
372+
.pattern({"aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TODO: Remove with functionalization
            // Elementwise greater-than against a scalar; the scalar is
            // materialized as a single-element constant tensor.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = tensor_to_const(ctx, torch::tensor({args[1].unwrapToScalar().to<float>()}));
            auto layer_name = util::node_info(n);

            auto gt = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, lhs, rhs, layer_name);
            TRTORCH_CHECK(gt, "Unable to create greater layer from node: " << *n);
            gt->setName(layer_name.c_str());

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gt->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
387+
.pattern({"aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TODO: Remove with functionalization
            // Elementwise less-than between two tensor inputs.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = args[1].ITensorOrFreeze(ctx);
            auto layer_name = util::node_info(n);

            auto lt = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, lhs, rhs, layer_name);
            TRTORCH_CHECK(lt, "Unable to create less layer from node: " << *n);
            lt->setName(layer_name.c_str());

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], lt->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
401+
.pattern({"aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TODO: Remove with functionalization
            // Elementwise less-than against a scalar; the scalar is
            // materialized as a single-element constant tensor.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = tensor_to_const(ctx, torch::tensor({args[1].unwrapToScalar().to<float>()}));
            auto layer_name = util::node_info(n);

            auto lt = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, lhs, rhs, layer_name);
            TRTORCH_CHECK(lt, "Unable to create less layer from node: " << *n);
            lt->setName(layer_name.c_str());

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], lt->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
416+
.pattern({"aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TODO: Remove with functionalization
            // Elementwise equality between two tensor inputs.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = args[1].ITensorOrFreeze(ctx);
            auto layer_name = util::node_info(n);

            auto eq = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, layer_name);
            TRTORCH_CHECK(eq, "Unable to create equal layer from node: " << *n);
            eq->setName(layer_name.c_str());

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], eq->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
430+
.pattern({"aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TODO: Remove with functionalization
            // Elementwise equality against a scalar; the scalar is
            // materialized as a single-element constant tensor.
            // NOTE(review): the constant is always float — presumably callers
            // pass float tensors; verify behavior for integer self tensors.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = tensor_to_const(ctx, torch::tensor({args[1].unwrapToScalar().to<float>()}));
            auto layer_name = util::node_info(n);

            auto eq = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, layer_name);
            TRTORCH_CHECK(eq, "Unable to create equal layer from node: " << *n);
            eq->setName(layer_name.c_str());

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], eq->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
445+
.pattern({"aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TODO: Remove with functionalization
            // ge(self, other) is lowered as (self > other) OR (self == other).
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = args[1].ITensorOrFreeze(ctx);
            auto layer_name = util::node_info(n);

            auto greater = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, lhs, rhs, layer_name + "_greater");
            TRTORCH_CHECK(greater, "Unable to create Greater layer from node: " << *n);

            auto equal = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, layer_name + "_equal");
            TRTORCH_CHECK(equal, "Unable to create Equal layer from node: " << *n);

            // Combine the two comparison results with a logical OR.
            auto or_op = ctx->net->addElementWise(*greater->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
            TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);
            or_op->setName(layer_name.c_str());

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
468+
.pattern({"aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TODO: Remove with functionalization
            // ge(self, scalar) is lowered as (self > scalar) OR (self == scalar),
            // with the scalar materialized as a single-element constant tensor.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = tensor_to_const(ctx, torch::tensor({args[1].unwrapToScalar().to<float>()}));
            auto layer_name = util::node_info(n);

            auto greater = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, lhs, rhs, layer_name + "_greater");
            TRTORCH_CHECK(greater, "Unable to create Greater layer from node: " << *n);

            auto equal = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, layer_name + "_equal");
            TRTORCH_CHECK(equal, "Unable to create Equal layer from node: " << *n);

            // Combine the two comparison results with a logical OR.
            auto or_op = ctx->net->addElementWise(*greater->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
            TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);
            or_op->setName(layer_name.c_str());

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
492+
.pattern({"aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TODO: Remove with functionalization
            // le(self, other) is lowered as (self < other) OR (self == other).
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = args[1].ITensorOrFreeze(ctx);
            auto layer_name = util::node_info(n);

            auto less = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, lhs, rhs, layer_name + "_less");
            TRTORCH_CHECK(less, "Unable to create Less layer from node: " << *n);

            auto equal = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, layer_name + "_equal");
            TRTORCH_CHECK(equal, "Unable to create Equal layer from node: " << *n);

            // Combine the two comparison results with a logical OR.
            auto or_op = ctx->net->addElementWise(*less->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
            TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);
            or_op->setName(layer_name.c_str());

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
515+
.pattern({"aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TODO: Remove with functionalization
            // le(self, scalar) is lowered as (self < scalar) OR (self == scalar),
            // with the scalar materialized as a single-element constant tensor.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = tensor_to_const(ctx, torch::tensor({args[1].unwrapToScalar().to<float>()}));
            auto layer_name = util::node_info(n);

            auto less = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, lhs, rhs, layer_name + "_less");
            TRTORCH_CHECK(less, "Unable to create Less layer from node: " << *n);

            auto equal = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, layer_name + "_equal");
            TRTORCH_CHECK(equal, "Unable to create Equal layer from node: " << *n);

            // Combine the two comparison results with a logical OR.
            auto or_op = ctx->net->addElementWise(*less->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
            TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);
            or_op->setName(layer_name.c_str());

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }});

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,88 @@ TEST(Converters, ATenNeScalarConvertsCorrectly) {
163163
pointwise_test_helper(graph, true, false, {3, 4, 2});
164164
;
165165
}
166+
167+
TEST(Converters, ATenGreaterThanConvertsCorrectly) {
  // aten::gt between two same-shaped tensor inputs.
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::gt(%0, %1)
        return (%2))IR";
  pointwise_test_helper(graph, false, false, {5, 5}, {5, 5});
}
174+
175+
TEST(Converters, ATenGreaterThanScalarConvertsCorrectly) {
  // aten::gt between a tensor and a constant scalar.
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=3]()
        %2 : Tensor = aten::gt(%0, %scalar)
        return (%2))IR";
  pointwise_test_helper(graph, true, false, {5, 5});
}
183+
184+
TEST(Converters, ATenLessThanConvertsCorrectly) {
  // aten::lt between two same-shaped tensor inputs.
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::lt(%0, %1)
        return (%2))IR";
  pointwise_test_helper(graph, false, false, {5, 5}, {5, 5});
}
191+
192+
TEST(Converters, ATenLessThanScalarConvertsCorrectly) {
  // aten::lt between a tensor and a constant scalar.
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=3]()
        %2 : Tensor = aten::lt(%0, %scalar)
        return (%2))IR";
  pointwise_test_helper(graph, true, false, {5, 5});
}
200+
201+
TEST(Converters, ATenEqualConvertsCorrectly) {
  // aten::eq between two same-shaped tensor inputs.
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::eq(%0, %1)
        return (%2))IR";
  pointwise_test_helper(graph, false, false, {5, 5}, {5, 5});
}
208+
209+
TEST(Converters, ATenEqualScalarConvertsCorrectly) {
  // aten::eq between a tensor and a constant scalar.
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=3]()
        %2 : Tensor = aten::eq(%0, %scalar)
        return (%2))IR";
  pointwise_test_helper(graph, true, false, {5, 5});
}
217+
218+
TEST(Converters, ATenGEConvertsCorrectly) {
  // aten::ge between two same-shaped tensor inputs.
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::ge(%0, %1)
        return (%2))IR";
  pointwise_test_helper(graph, false, false, {5, 5}, {5, 5});
}
225+
226+
TEST(Converters, ATenGEScalarConvertsCorrectly) {
  // aten::ge between a tensor and a constant scalar.
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=3]()
        %2 : Tensor = aten::ge(%0, %scalar)
        return (%2))IR";
  pointwise_test_helper(graph, true, false, {5, 5});
}
234+
235+
TEST(Converters, ATenLEConvertsCorrectly) {
  // aten::le between two same-shaped tensor inputs.
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::le(%0, %1)
        return (%2))IR";
  pointwise_test_helper(graph, false, false, {5, 5}, {5, 5});
}
242+
243+
TEST(Converters, ATenLEScalarConvertsCorrectly) {
  // aten::le between a tensor and a constant scalar.
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=3]()
        %2 : Tensor = aten::le(%0, %scalar)
        return (%2))IR";
  pointwise_test_helper(graph, true, false, {5, 5});
}

0 commit comments

Comments
 (0)