Skip to content

Commit 787c958

Browse files
authored
Merge pull request #307 from NVIDIA/new_evaluators
Adding aten::is_floating_point evaluator
2 parents 952d090 + 22d65a4 commit 787c958

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,22 @@ auto aten_registrations TRTORCH_UNUSED =
585585
return {};
586586
},
587587
EvalOptions()})
588+
.evaluator({c10::Symbol::fromQualString("aten::is_floating_point"),
589+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
590+
auto tensor_var = args.at(n->input(0));
591+
if (tensor_var.isITensor()) {
592+
auto tensor = tensor_var.ITensor();
593+
auto t = tensor->getType();
594+
return (t == nvinfer1::DataType::kFLOAT || t == nvinfer1::DataType::kHALF);
595+
} else {
596+
auto tensor = tensor_var.unwrapToTensor();
597+
auto t = tensor.scalar_type();
598+
return at::isFloatingType(t);
599+
}
600+
},
601+
EvalOptions().validSchemas({
602+
"aten::is_floating_point(Tensor self) -> (bool)",
603+
})})
588604
.evaluator(
589605
{c10::Symbol::fromQualString("aten::tensor"),
590606
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {

tests/accuracy/datasets/cifar10.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ std::pair<torch::Tensor, torch::Tensor> read_batch(const std::string& path) {
5151
labels.push_back(label);
5252
auto image_tensor =
5353
torch::from_blob(image.data(), {kImageChannels, kImageDim, kImageDim}, torch::TensorOptions().dtype(torch::kU8))
54-
.to(torch::kF32).div(255);
54+
.to(torch::kF32)
55+
.div(255);
5556
images.push_back(image_tensor);
5657
}
5758

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,39 @@ TEST(Evaluators, IntFloatEvaluatesCorrectly) {
471471

472472
ASSERT_TRUE(jit_results[0] == trt_results[0]);
473473
}
474+
475+
// Verifies that the aten::is_floating_point evaluator agrees with TorchScript
// for a floating-point input: a kF32 tensor must evaluate to true on both paths.
TEST(Evaluators, ATenIsFloatingPointEvaluatesTrueCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %1 : bool = aten::is_floating_point(%0)
        return (%1))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // Random integer-valued tensor cast to float32 — should report as floating point.
  auto jit_in = at::randint(1, 10, {1, 3, 3, 3}, {at::kCUDA}).to(torch::kF32);
  auto trt_in = jit_in.clone();

  auto jit_results = trtorch::tests::util::EvaluateGraphJIT(parsed_graph, {jit_in});
  auto trt_results = trtorch::tests::util::EvaluateGraph(parsed_graph->block(), {trt_in});

  ASSERT_TRUE(jit_results[0] == trt_results[0]);
}
492+
493+
// Verifies that the aten::is_floating_point evaluator agrees with TorchScript
// for a non-floating-point input: a kI8 tensor must evaluate to false on both paths.
TEST(Evaluators, ATenIsFloatingPointEvaluatesFalseCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %1 : bool = aten::is_floating_point(%0)
        return (%1))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // Random integer-valued tensor cast to int8 — should NOT report as floating point.
  auto jit_in = at::randint(1, 10, {1, 3, 3, 3}, {at::kCUDA}).to(torch::kI8);
  auto trt_in = jit_in.clone();

  auto jit_results = trtorch::tests::util::EvaluateGraphJIT(parsed_graph, {jit_in});
  auto trt_results = trtorch::tests::util::EvaluateGraph(parsed_graph->block(), {trt_in});

  ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

0 commit comments

Comments
 (0)