Skip to content

Commit bf5718d

Browse files
committed
replace util::arrToTensor with tensor_to_const and remove addSliceInput
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 309b701 commit bf5718d

File tree

5 files changed

+14
-79
lines changed

5 files changed

+14
-79
lines changed

core/conversion/converters/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ cc_library(
5858
deps = [
5959
"@tensorrt//:nvinfer",
6060
"//core/util:prelude",
61-
"//core/util:converter_util",
6261
"//core/conversion/var",
6362
"//core/conversion/tensorcontainer",
6463
"//core/conversion/conversionctx",

core/conversion/converters/impl/expand.cpp

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include "NvInfer.h"
22
#include "core/conversion/converters/converters.h"
33
#include "core/conversion/tensorcontainer/TensorContainer.h"
4-
#include "core/util/converter_util.h"
54
#include "core/util/prelude.h"
65
#include "core/util/trt_util.h"
76
#include "torch/torch.h"
@@ -16,25 +15,12 @@ namespace converters {
1615
namespace impl {
1716
namespace {
1817

19-
void addSliceInput(nvinfer1::Dims& dims, int idx, ConversionCtx* ctx, nvinfer1::ISliceLayer* slice) {
20-
int32_t rank = static_cast<int32_t>(dims.nbDims);
21-
int32_t* tmp = new int32_t[rank];
22-
for (int i = 0; i < rank; i++)
23-
tmp[i] = dims.d[i];
24-
const nvinfer1::Dims d{1, {rank}};
25-
const nvinfer1::Weights w{nvinfer1::DataType::kINT32, tmp, rank};
26-
auto t = ctx->net->addConstant(d, w)->getOutput(0);
27-
slice->setInput(idx, *t);
28-
}
29-
3018
nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfer1::ITensor* tensor) {
3119
if (max_rank - old_rank > 0) {
32-
int32_t* tmp = new int32_t[max_rank - old_rank];
33-
for (int i = 0; i < (max_rank - old_rank); i++)
34-
tmp[i] = 1;
35-
auto max_rank_tensor = util::arrToTensor(tmp, max_rank - old_rank, ctx);
20+
torch::Tensor thOne = torch::tensor(std::vector<int32_t>(max_rank - old_rank, 1), torch::kInt32);
21+
auto one_tensor = tensor_to_const(ctx, thOne);
3622
auto in_shape_tensor = ctx->net->addShape(*tensor)->getOutput(0);
37-
nvinfer1::ITensor* const args[2] = {max_rank_tensor, in_shape_tensor};
23+
nvinfer1::ITensor* const args[2] = {one_tensor, in_shape_tensor};
3824
return ctx->net->addConcatenation(args, 2)->getOutput(0);
3925
} else { // max_rank - old_rank == 0
4026
return ctx->net->addShape(*tensor)->getOutput(0);
@@ -166,7 +152,6 @@ bool add_expand_dynamic(
166152

167153
// Dimensions are right alignment. Eg: an input of [3, 1] and max_rank = 4, the result of concat is [1, 1, 3, 1]
168154
auto new_input_shape_tensor = concat(max_rank, input_rank, ctx, in);
169-
// LOG_DEBUG("Expand layer output tensor shape: " << new_output_shape_tensor->getDimensions());
170155
auto new_output_shape_tensor = expandedDimsTensor;
171156

172157
// Add a reshape layer to expand dims
@@ -176,6 +161,8 @@ bool add_expand_dynamic(
176161
// Start the slicing from beginning of tensor since this is an expand layer
177162
std::vector<int64_t> start_vec(max_rank, 0);
178163
nvinfer1::Dims starts_dim = util::toDims(c10::IntArrayRef(start_vec));
164+
at::Tensor thStart = torch::tensor(util::toVec(starts_dim), torch::kInt32);
165+
auto starts = tensor_to_const(ctx, thStart);
179166

180167
// compute sizes = max(x,y).
181168
auto sizes =
@@ -186,18 +173,17 @@ bool add_expand_dynamic(
186173

187174
// Compute (x > 1 ? 1 : 0) for x in newDims, assuming positive x, using only TensorRT operations.
188175
// min(1, sub(input_shape, 1))
189-
int32_t* one_vector_tmp = new int32_t[1];
190-
one_vector_tmp[0] = 1;
191-
auto one_vector = util::arrToTensor(one_vector_tmp, 1, ctx);
192-
auto x_sub_one = ctx->net->addElementWise(*new_input_shape_tensor, *one_vector, nvinfer1::ElementWiseOperation::kSUB)
176+
torch::Tensor thOne = torch::tensor({1}, torch::kInt32);
177+
auto one_tensor = tensor_to_const(ctx, thOne);
178+
auto x_sub_one = ctx->net->addElementWise(*new_input_shape_tensor, *one_tensor, nvinfer1::ElementWiseOperation::kSUB)
193179
->getOutput(0);
194-
auto strides = ctx->net->addElementWise(*one_vector, *x_sub_one, nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
180+
auto strides = ctx->net->addElementWise(*one_tensor, *x_sub_one, nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
195181
nvinfer1::Dims strides_dim{-1, {}};
196182
strides_dim.nbDims = max_rank;
197183

198-
// Slice layer does the expansion in TRT. Desired output size is specified by expandedDimsTensor
184+
// Slice layer does the expansion in TRT. Desired output size is specified by sizes input at index 2.
199185
auto slice = ctx->net->addSlice(*shuffle->getOutput(0), starts_dim, sizes_dim, strides_dim);
200-
addSliceInput(starts_dim, 1, ctx, slice);
186+
slice->setInput(1, *starts);
201187
slice->setInput(2, *sizes);
202188
slice->setInput(3, *strides);
203189

@@ -219,11 +205,8 @@ auto expand_registrations TRTORCH_UNUSED =
219205
auto expandedDims = util::toDims(expanded_size);
220206
LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
221207
if (ctx->input_is_dynamic) {
222-
int expanded_size_rank = static_cast<int>(expanded_size.size());
223-
int32_t* tmp = new int32_t[expanded_size_rank];
224-
for (int i = 0; i < expanded_size_rank; i++)
225-
tmp[i] = expanded_size[i];
226-
auto expandedDimsTensor = util::arrToTensor(tmp, expanded_size_rank, ctx);
208+
at::Tensor thExpanded_size = torch::tensor(expanded_size.vec(), torch::kInt32);
209+
auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
227210
return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
228211
} else {
229212
return add_expand(ctx, n, in, expandedDims);

core/util/BUILD

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -84,23 +84,6 @@ cc_library(
8484
})
8585
)
8686

87-
cc_library(
88-
name = "converter_util",
89-
hdrs = [
90-
"converter_util.h",
91-
],
92-
srcs = [
93-
"converter_util.cpp"
94-
],
95-
deps = [
96-
"//core/conversion/conversionctx"
97-
]+ select({
98-
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
99-
"//conditions:default": ["@libtorch//:libtorch"],
100-
}),
101-
)
102-
103-
10487
load("@rules_pkg//:pkg.bzl", "pkg_tar")
10588

10689
pkg_tar(
@@ -112,7 +95,6 @@ pkg_tar(
11295
"//core/util:Exception.h",
11396
"//core/util:prelude.h",
11497
"//core/util:jit_util.h",
115-
"//core/util:trt_util.h",
116-
"//core/util:converter_util.h"
98+
"//core/util:trt_util.h"
11799
],
118100
)

core/util/converter_util.cpp

Lines changed: 0 additions & 15 deletions
This file was deleted.

core/util/converter_util.h

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments (0)