Skip to content

Commit da15d9a

Browse files
committed
replace int with int64_t and rm truncate long/double to int/float in Var.cpp
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 3990787 commit da15d9a

File tree

2 files changed

+4
-11
lines changed

2 files changed

+4
-11
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
2828
auto split_size = args[1].unwrapToInt();
2929
numOutputs = inDimSize / split_size;
3030
numRemainder = inDimSize % split_size;
31-
for (int i = 0; i < numOutputs; i++) {
31+
for (int64_t i = 0; i < numOutputs; i++) {
3232
sizes.push_back(split_size);
3333
}
3434
if (numRemainder) {
@@ -45,7 +45,7 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
4545
list.reserve(numOutputs);
4646

4747
int start_idx = 0;
48-
for (int i = 0; i < numOutputs; i++) {
48+
for (int64_t i = 0; i < numOutputs; i++) {
4949
at::Tensor indices = torch::arange(start_idx, start_idx + sizes[i], 1).to(torch::kI32);
5050
auto indicesTensor = tensor_to_const(ctx, indices);
5151

core/conversion/var/Var.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,10 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
9494
"Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());
9595

9696
nvinfer1::ITensor* out;
97-
auto weights = converters::Weights();
97+
9898
if (isIValue()) {
9999
if (ptr_.ivalue->isTensor()) {
100-
auto tensor = ptr_.ivalue->toTensor();
101-
if (tensor.scalar_type() == at::kLong) {
102-
weights = converters::Weights(ctx, tensor.toType(at::kInt));
103-
} else if (tensor.scalar_type() == at::kDouble) {
104-
weights = converters::Weights(ctx, tensor.toType(at::kFloat));
105-
} else {
106-
weights = converters::Weights(ctx, tensor);
107-
}
100+
auto weights = converters::Weights(ctx, ptr_.ivalue->toTensor());
108101

109102
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
110103
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");

0 commit comments

Comments
 (0)