Commit 1d55db2

Merge pull request #382 from NVIDIA/aten_to
feat(aten::to): Add support for cast layer conversion
2 parents: bf92101 + 3e87c74

File tree: 19 files changed, +412 −65 lines

core/conversion/conversion_ignorelist.cpp

Lines changed: 0 additions & 1 deletion

@@ -16,7 +16,6 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
       "aten::backward",
       "aten::save",
       "aten::contiguous",
-      "aten::to",
       "prim::RaiseException",
       "prim::Print",
       "prim::device",

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@ cc_library(
         "NodeConverterRegistry.cpp",
         "impl/activation.cpp",
         "impl/batch_norm.cpp",
+        "impl/cast.cpp",
         "impl/concat.cpp",
         "impl/constant.cpp",
         "impl/constant_pad.cpp",

core/conversion/converters/converter_util.cpp

Lines changed: 9 additions & 3 deletions

@@ -142,7 +142,7 @@ nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype) {
   }
 }
 
-nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
+nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name) {
   bool post_freeze_cast = false;
   nvinfer1::DataType post_freeze_cast_type = nvinfer1::DataType::kFLOAT;
   // Other "unsupported weights types" can be added to this check here
@@ -175,9 +175,15 @@ nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
 
   std::ostringstream tensor_id;
   tensor_id << reinterpret_cast<int*>(out);
+  std::string tensor_name;
 
-  LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
-  const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());
+  if (!name.empty()) {
+    tensor_name = name;
+  } else {
+    tensor_name = tensor_id.str();
+  }
+  LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_name << " as an IConstantLayer");
+  const_layer->setName(("[Freeze Tensor " + tensor_name + " ]").c_str());
 
   if (post_freeze_cast) {
     out = castITensor(ctx, out, post_freeze_cast_type);

core/conversion/converters/converter_util.h

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ nvinfer1::ILayer* add_elementwise(
 nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype);
 
 // Freeze an at::Tensor in a IConstant layer
-nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t);
+nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name = std::string());
 
 } // namespace converters
 } // namespace conversion
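
castITensor is used by both tensor_to_const and the new cast converter, but its body sits outside this diff. As a rough sketch only — assuming TensorRT's common identity-layer cast idiom, not the project's actual implementation — a cast helper of this shape looks roughly like:

// Hypothetical sketch; trtorch's real castITensor lives in converter_util.cpp
// and additionally handles layer naming and ConversionCtx bookkeeping.
#include "NvInfer.h"

nvinfer1::ITensor* cast_sketch(nvinfer1::INetworkDefinition* net,
                               nvinfer1::ITensor* tensor,
                               nvinfer1::DataType dtype) {
  if (tensor->getType() == dtype) {
    return tensor; // already the requested type, nothing to insert
  }
  // An identity layer with an overridden output type performs the conversion
  auto* id_layer = net->addIdentity(*tensor);
  id_layer->setOutputType(0, dtype);
  return id_layer->getOutput(0);
}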
core/conversion/converters/impl/cast.cpp (new file, per the BUILD entry above)

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+#include <torch/torch.h>
+#include "core/conversion/converters/converter_util.h"
+#include "core/conversion/converters/converters.h"
+#include "core/util/prelude.h"
+#include "core/util/trt_util.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+auto cast_registrations TRTORCH_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern(
+            {"aten::to.dtype(Tensor self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto output_dtype = args[1].unwrapToScalar().to<int64_t>();
+               auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+               auto casted_itensor = castITensor(ctx, self, trt_dtype);
+               auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
+               LOG_DEBUG("[aten::to.dtype] Output tensor shape: " << output->getDimensions());
+
+               return true;
+             }})
+        .pattern(
+            {"aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               nvinfer1::DataType other_dtype = args[1].ITensorOrFreeze(ctx)->getType();
+               auto casted_itensor = castITensor(ctx, self, other_dtype);
+               auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
+               LOG_DEBUG("[aten::to.other] Output tensor shape: " << output->getDimensions());
+
+               return true;
+             }})
+        .pattern(
+            {"aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(b|a))",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               if (args[2].isIValue() && !args[2].IValue()->isScalar()) {
+                 auto output = ctx->AssociateValueAndTensor(n->outputs()[0], self);
+                 LOG_DEBUG("[aten::to.prim_Device] Output tensor shape: " << output->getDimensions());
+                 return true;
+               }
+
+               auto output_dtype = args[2].unwrapToScalar().to<int64_t>();
+               auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+               auto casted_itensor = castITensor(ctx, self, trt_dtype);
+               auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
+               LOG_DEBUG("[aten::to.prim_Device] Output tensor shape: " << output->getDimensions());
+
+               return true;
+             }});
+// clang-format on
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
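
For context, a minimal libtorch snippet — an illustration, not part of this commit — that produces a graph containing the aten::to.dtype node these patterns match:

#include <torch/script.h>
#include <iostream>

int main() {
  // Scripting this function yields a TorchScript graph with an aten::to.dtype
  // node, which the converter above now maps to a TensorRT cast
  auto cu = torch::jit::compile(R"JIT(
    def cast_to_int(x):
        return x.to(torch.int32)
  )JIT");
  // Print the graph to see the aten::to node the converter matches
  std::cout << *cu->get_function("cast_to_int").graph() << std::endl;
  return 0;
}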

core/conversion/converters/impl/constant.cpp

Lines changed: 10 additions & 6 deletions

@@ -16,12 +16,16 @@ auto constant_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
                        // used for Fundimentally this is because of the differing
                        // philosophies between TensorRT and PyTorch, i.e. Variables contain
                        // Tensors vs. just Tensors
-
-                       auto t = args[0].unwrapToTensor();
-                       auto const_out = ctx->AssociateValueAndTensor(n->outputs()[0], tensor_to_const(ctx, t));
-
-                       LOG_DEBUG("Output tensor shape: " << const_out->getDimensions());
-
+                       nvinfer1::ITensor* output;
+                       if (args[0].isITensor()){
+                         output = ctx->AssociateValueAndTensor(n->outputs()[0], args[0].ITensor());
+                       } else{
+                         auto t = args[0].unwrapToTensor();
+                         auto const_out = tensor_to_const(ctx, t, util::node_info(n).c_str());
+                         output = ctx->AssociateValueAndTensor(n->outputs()[0], const_out);
+                       }
+                       LOG_DEBUG("Output tensor shape: " << output->getDimensions());
+
                        return true;
                      }});
                      // clang-format on

core/conversion/converters/impl/shuffle.cpp

Lines changed: 27 additions & 0 deletions

@@ -125,6 +125,33 @@ static auto shuffle_registrations TRTORCH_UNUSED =
                     auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
 
+                    return true;
+                  }})
+        .pattern({"aten::t(Tensor self) -> Tensor",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto in = args[0].ITensorOrFreeze(ctx);
+                    auto input_dims = in->getDimensions();
+                    // For input tensors < 2D, return them as is
+                    // For a 2D input tensor, return transpose(input, 0, 1) which is a general 2d matrix transpose.
+                    if (input_dims.nbDims < 2) {
+                      auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+                      LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+                      return true;
+                    }
+
+                    auto shuffle_layer = ctx->net->addShuffle(*in);
+                    TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
+                    nvinfer1::Permutation firstPerm;
+                    firstPerm.order[0] = 1;
+                    firstPerm.order[1] = 0;
+
+                    shuffle_layer->setFirstTranspose(firstPerm);
+                    shuffle_layer->setZeroIsPlaceholder(false);
+                    shuffle_layer->setName(util::node_info(n).c_str());
+
+                    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
                     return true;
                   }})
         .pattern({"aten::pixel_shuffle(Tensor self, int upscale_factor) -> (Tensor)",

core/conversion/evaluators/aten.cpp

Lines changed: 0 additions & 14 deletions

@@ -468,20 +468,6 @@ auto aten_registrations TRTORCH_UNUSED =
                     EvalOptions().validSchemas({
                         "aten::numel(Tensor self) -> int",
                     })})
-        .evaluator({c10::Symbol::fromQualString("aten::t"),
-                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      auto tensor_var = args.at(n->input(0));
-                      if (tensor_var.IValue()->isTensor()) {
-                        auto tensor = tensor_var.unwrapToTensor();
-                        return tensor.t();
-                      } else {
-                        TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator: ITensor");
-                        return {};
-                      }
-                    },
-                    EvalOptions().validSchemas({
-                        "aten::t(Tensor self) -> Tensor",
-                    })})
         .evaluator({c10::Symbol::fromQualString("aten::dim"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto tensor_var = args.at(n->input(0));

core/conversion/evaluators/prim.cpp

Lines changed: 23 additions & 0 deletions

@@ -12,6 +12,7 @@
 #include "core/conversion/evaluators/eval_macros.h"
 #include "core/conversion/evaluators/eval_util.h"
 #include "core/conversion/evaluators/evaluators.h"
+#include "core/util/trt_util.h"
 
 namespace trtorch {
 namespace core {
@@ -101,6 +102,28 @@ auto prim_registrations =
                       return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
                     }
                   }})
+        .evaluator({c10::Symbol::fromQualString("prim::dtype"),
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      auto input = args.at(n->input(0));
+                      if (input.isITensor()) {
+                        auto trt_dtype = input.ITensor()->getType();
+                        return static_cast<int>(util::TRTDataTypeToScalarType(trt_dtype));
+                      } else if (input.isIValue()) {
+                        if (input.IValue()->isTensor()) {
+                          auto pyt_input = input.IValue()->toTensor();
+                          return static_cast<int>(pyt_input.scalar_type());
+                        } else {
+                          TRTORCH_THROW_ERROR("Unsupported input type in prim::dtype operator");
+                          return {};
+                        }
+                      } else {
+                        TRTORCH_THROW_ERROR("Unsupported input type in prim::dtype operator");
+                        return {};
+                      }
+                    },
+                    EvalOptions().validSchemas({
+                        "prim::dtype(Tensor a) -> (int)",
+                    })})
         .evaluator({c10::Symbol::fromQualString("prim::min"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       if (n->inputs().size() == 1) {
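
The new evaluator returns the dtype as the integer value of the at::ScalarType enum, which is what prim::dtype produces in TorchScript. A quick illustration (not from this commit):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::ones({2, 2}, torch::kFloat32);
  // prim::dtype evaluates to the ScalarType's integer value;
  // at::ScalarType::Float is 6 in the enum
  std::cout << static_cast<int>(x.scalar_type()) << std::endl; // 6
  return 0;
}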

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::EliminateExceptionOrPassPattern(g);
   torch::jit::FuseLinear(g);
   torch::jit::LowerAllTuples(g);
+  passes::ReduceToOperation(g);
   passes::RemoveContiguous(g);
   passes::RemoveDropout(g);
   passes::LinearToAddMM(g);
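
The body of ReduceToOperation is outside this diff. As a hedged sketch of what such a lowering pass can look like — the names and IR patterns below are assumptions, not the project's actual code — torch::jit::SubgraphRewriter can canonicalize the device overload of aten::to into the dtype overload, so a single converter handles both:

#include <torch/csrc/jit/passes/subgraph_rewrite.h>

// Sketch only: rewrite aten::to(self, device, dtype, ...) into
// aten::to(self, dtype, ...), dropping the device argument, which has no
// meaning inside a TensorRT engine.
void reduce_to_sketch(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string to_device_pattern = R"IR(
    graph(%x, %device, %dtype, %nb, %copy, %format):
      %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
      return (%out))IR";
  std::string to_dtype_pattern = R"IR(
    graph(%x, %device, %dtype, %nb, %copy, %format):
      %out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
      return (%out))IR";

  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(to_device_pattern, to_dtype_pattern);
  rewriter.runOnGraph(graph);
}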
