Skip to content

Commit 7cd52cf

Browse files
authored
Merge pull request #325 from inocsin/fix_arange
Add converter aten::arange
2 parents cdab9ec + d19682b commit 7cd52cf

File tree

2 files changed

+158
-1
lines changed

2 files changed

+158
-1
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,61 @@ auto aten_registrations TRTORCH_UNUSED =
467467
LOG_WARNING("Warning from TorchScript: " << *warning);
468468
return {};
469469
},
470-
EvalOptions()});
470+
EvalOptions()})
471+
.evaluator({c10::Symbol::fromQualString("aten::arange"),
            [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
              // Evaluates aten::arange at conversion time by materializing the
              // result with torch::arange. Only start/end/step are honored; the
              // trailing optional args (dtype, layout, device, pin_memory) are
              // ignored. NOTE(review): an explicitly supplied dtype is silently
              // dropped -- dtype support should be added or rejected loudly.
              //
              // Dispatch on the node's arity rather than counting scalar-typed
              // inputs: a user-supplied dtype is itself an int scalar, so
              // counting scalars miscategorizes e.g. arange(start, end, dtype=...)
              // as a start/end/step call and uses the dtype value as the step.
              // Per the registered schemas:
              //   aten::arange(end, *, ...)                         -> 5 inputs
              //   aten::arange.start(start, end, *, ...)            -> 6 inputs
              //   aten::arange.start_step(start, end, step, *, ...) -> 7 inputs
              auto input_size = n->inputs().size();

              // Helpers over the positional scalar inputs. Floating scalars are
              // unwrapped as double (not float) so the Scalar payload is not
              // narrowed before torch::arange sees it.
              auto is_double = [&](size_t i) { return args.at(n->input(i)).IValue()->isDouble(); };
              auto to_double = [&](size_t i) { return args.at(n->input(i)).unwrapToScalar().to<double>(); };
              auto to_int = [&](size_t i) { return args.at(n->input(i)).unwrapToInt(); };

              if (input_size == 5) {
                // aten::arange(Scalar end, *, ...)
                if (is_double(0)) {
                  return torch::arange(to_double(0));
                }
                return torch::arange(to_int(0));
              } else if (input_size == 6) {
                // aten::arange.start(Scalar start, Scalar end, *, ...)
                // Any floating positional argument promotes the whole call.
                if (is_double(0) || is_double(1)) {
                  return torch::arange(to_double(0), to_double(1));
                }
                return torch::arange(to_int(0), to_int(1));
              } else if (input_size == 7) {
                // aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ...)
                if (is_double(0) || is_double(1) || is_double(2)) {
                  return torch::arange(to_double(0), to_double(1), to_double(2));
                }
                return torch::arange(to_int(0), to_int(1), to_int(2));
              } else {
                TRTORCH_THROW_ERROR(
                    "Invalid input argument size for aten::arange, input argument size: " << input_size);
              }
              return {};
            },
            EvalOptions().validSchemas({
                R"SIG(aten::arange(Scalar end, *, int? dtype=None, int? layout=None,
                Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
                R"SIG(aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None,
                Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
                R"SIG(aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None,
                Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
            })});
471525
} // namespace
472526
} // namespace evaluators
473527
} // namespace conversion

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,107 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
7575
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
7676

7777
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
78+
}
79+
80+
TEST(Evaluators, ATenArangeIntEvaluatesCorrectly) {
  // arange with a single integer end; all optional args left as None.
  const auto source_ir = R"IR(
      graph():
        %0 : int = prim::Constant[value=51]()
        %1 : None = prim::Constant()
        %2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
        return (%2))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // The TRT evaluator must reproduce what TorchScript itself computes.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_graph, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_graph->block(), {});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}
96+
97+
TEST(Evaluators, ATenArangeFloatEvaluatesCorrectly) {
  // arange with a single floating-point end; all optional args left as None.
  const auto source_ir = R"IR(
      graph():
        %0 : float = prim::Constant[value=51.2]()
        %1 : None = prim::Constant()
        %2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
        return (%2))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // The TRT evaluator must reproduce what TorchScript itself computes.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_graph, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_graph->block(), {});
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}
112+
113+
TEST(Evaluators, ATenArangeStartEndIntEvaluatesCorrectly) {
  // arange.start overload with integer start/end; optionals all None.
  const auto source_ir = R"IR(
      graph():
        %0 : int = prim::Constant[value=1]()
        %1 : int = prim::Constant[value=51]()
        %2 : None = prim::Constant()
        %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
        return (%3))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // The TRT evaluator must reproduce what TorchScript itself computes.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_graph, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_graph->block(), {});
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}
129+
130+
TEST(Evaluators, ATenArangeStartEndFloatEvaluatesCorrectly) {
  // arange.start overload with floating-point start/end; optionals all None.
  const auto source_ir = R"IR(
      graph():
        %0 : float = prim::Constant[value=1.5]()
        %1 : float = prim::Constant[value=51.2]()
        %2 : None = prim::Constant()
        %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
        return (%3))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // The TRT evaluator must reproduce what TorchScript itself computes.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_graph, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_graph->block(), {});
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}
146+
147+
TEST(Evaluators, ATenArangeStartEndStepIntEvaluatesCorrectly) {
  // arange.start_step overload with integer start/end/step; optionals all None.
  const auto source_ir = R"IR(
      graph():
        %0 : int = prim::Constant[value=1]()
        %1 : int = prim::Constant[value=51]()
        %2 : int = prim::Constant[value=1]()
        %3 : None = prim::Constant()
        %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
        return (%4))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // The TRT evaluator must reproduce what TorchScript itself computes.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_graph, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_graph->block(), {});
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}
164+
165+
TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
  // arange.start_step overload with floating-point start/end/step; optionals all None.
  const auto source_ir = R"IR(
      graph():
        %0 : float = prim::Constant[value=1.2]()
        %1 : float = prim::Constant[value=51.6]()
        %2 : float = prim::Constant[value=1.5]()
        %3 : None = prim::Constant()
        %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
        return (%4))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // The TRT evaluator must reproduce what TorchScript itself computes.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_graph, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_graph->block(), {});
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}

0 commit comments

Comments
 (0)