Skip to content

Commit 4095999

Browse files
authored
Merge pull request #354 from NVIDIA/fix_zeros
Verify that the zeros evaluator works correctly and detect programs that would result in an empty TensorRT engine
2 parents c9baca5 + 0f783da commit 4095999

File tree

3 files changed

+84
-5
lines changed

3 files changed

+84
-5
lines changed

core/conversion/conversion.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,30 @@ std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
433433
return unsupported_ops;
434434
}
435435

436+
// Recursively collects the schemas of every node in block `b` (including the
// bodies of nested prim::Loop / prim::If sub-blocks) that TRTorch's converter
// registry can map to TensorRT layers.
//
// @param b  JIT block to scan (not modified; may contain nested blocks).
// @return   Set of schema strings, one per convertable node. An empty set
//           means compiling this block would produce an empty TensorRT engine.
std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
  std::set<std::string> convertable_ops;
  for (const auto n : b->nodes()) {
    // Hoisted: the original evaluated node_is_convertable(n) twice per node.
    const bool is_convertable = converters::node_is_convertable(n);
    if (n->kind() == torch::jit::prim::Loop || n->kind() == torch::jit::prim::If || is_convertable) {
      // Recurse into sub-blocks (loop/if bodies) so nested convertable ops are
      // counted too. A node with no sub-blocks simply skips this loop.
      for (const auto sub_b : n->blocks()) {
        auto sub_b_convertable_ops = ConvertableOpsInBlock(sub_b);
        convertable_ops.insert(sub_b_convertable_ops.begin(), sub_b_convertable_ops.end());
      }
      if (is_convertable) {
        auto schema = n->maybeSchema();
        TRTORCH_CHECK(
            schema, "Unable to get schema for Node " << util::node_info(n) << " (conversion.CheckForConvertableOps)");
        std::stringstream ss;
        ss << *schema;
        convertable_ops.insert(ss.str());
      }
    }
  }
  return convertable_ops;
}
459+
436460
bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
437461
auto unsupported_ops = GetUnsupportedOpsInBlock(b);
438462

@@ -448,7 +472,21 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
448472
unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
449473
LOG_ERROR(unsupported_msg.str());
450474
return false;
451-
} else {
475+
}
476+
477+
if (ConvertableOpsInBlock(b).size() == 0) {
478+
std::stringstream unsupported_msg;
479+
unsupported_msg
480+
<< "Method requested cannot be compiled by TRTorch.\nThere is no work to be done since the resulting compiled program will contain an engine that is empty."
481+
<< std::endl;
482+
unsupported_msg
483+
<< "This may be because there are no operators that can be added to the TensorRT graph or all operators have a resolved compile time value."
484+
<< std::endl;
485+
LOG_ERROR(unsupported_msg.str());
486+
return false;
487+
}
488+
489+
else {
452490
return true;
453491
}
454492
}

core/conversion/evaluators/aten.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,12 @@ auto aten_registrations TRTORCH_UNUSED =
118118
// aten::zeros(int[] size, *, int? dtype=None, int? layout=None,
119119
// Device? device=None, bool? pin_memory=None) -> (Tensor)
120120
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
121-
auto options = torch::TensorOptions()
122-
.dtype(c10::ScalarType(args.at(n->output(1)).unwrapToInt()))
123-
.layout(torch::kStrided)
124-
.device(torch::kCUDA);
121+
auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
122+
123+
// Input 1 here is the dtype
124+
if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
125+
options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
126+
}
125127

126128
auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
127129
return out_tensor;

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,43 @@ TEST(Evaluators, DivFloatEvaluatesCorrectly) {
3636
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
3737

3838
ASSERT_TRUE(jit_results[0] == trt_results[0]);
39+
}
40+
41+
// Checks that the aten::zeros evaluator matches the JIT interpreter when every
// optional argument (dtype, layout, device, pin_memory) is None (%2).
TEST(Evaluators, ZerosEvaluatesCorrectly) {
42+
const auto graph = R"IR(
43+
graph(%x.1 : Tensor):
44+
%2 : None = prim::Constant() # :0:0
45+
%3 : int[] = aten::size(%x.1) # <string>:7:9
46+
%z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2) # experiments/test_zeros.py:8:12
47+
return (%z.1))IR";
48+
49+
// The input only supplies the shape that aten::size feeds into aten::zeros;
// its random values are never read.
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
50+
51+
// Parse the textual IR above into a JIT graph.
auto g = std::make_shared<torch::jit::Graph>();
52+
torch::jit::parseIR(graph, &*g);
53+
54+
// Run the graph through the JIT interpreter and through the TRTorch
// evaluator path, then compare the resulting tensors element-wise.
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
55+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
56+
57+
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
58+
}
59+
60+
// Checks that the aten::zeros evaluator honors an explicit dtype argument
// (%2, constant 5 — Float16 per the inline IR comment) while the remaining
// optional arguments stay None (%3).
TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
61+
const auto graph = R"IR(
62+
graph(%x.1 : Tensor):
63+
%2 : int = prim::Constant[value=5]() # :0:0 (Float16)
64+
%3 : None = prim::Constant() # :0:0
65+
%4 : int[] = aten::size(%x.1) # <string>:7:9
66+
%z.1 : Tensor = aten::zeros(%4, %2, %3, %3, %3) # experiments/test_zeros.py:8:12
67+
return (%z.1))IR";
68+
69+
// The input only supplies the shape that aten::size feeds into aten::zeros;
// its random values are never read.
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
70+
71+
// Parse the textual IR above into a JIT graph.
auto g = std::make_shared<torch::jit::Graph>();
72+
torch::jit::parseIR(graph, &*g);
73+
74+
// Run the graph through the JIT interpreter and through the TRTorch
// evaluator path, then compare the resulting tensors element-wise.
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
75+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
76+
77+
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
3978
}

0 commit comments

Comments
 (0)