Skip to content

Commit 979d9d1

Browse files
committed
fix: support end=9223372036854775807 (default value) in aten::slice
Signed-off-by: inocsin <[email protected]>
1 parent 72580eb commit 979d9d1

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

core/conversion/converters/converter_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <map>
44
#include <string>
5+
#include <limits>
56

67
#include "core/conversion/conversionctx/ConversionCtx.h"
78
#include "core/conversion/converters/Weights.h"

core/conversion/converters/impl/select.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -307,37 +307,38 @@ auto select_registrations TORCHTRT_UNUSED =
307307
{"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
308308
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
309309
auto in = args[0].ITensorOrFreeze(ctx);
310-
auto axis = args[1].unwrapToInt();
311-
auto maxDim = static_cast<int64_t>(in->getDimensions().d[axis]);
310+
int axis = args[1].unwrapToInt();
311+
int maxDim = static_cast<int32_t>(in->getDimensions().d[axis]);
312312
bool dynamic_shape = is_dynamic_shape(in);
313313
auto input_dim = in->getDimensions();
314314
// add Shape Tensor
315315
auto ishape_layer = ctx->net->addShape(*in);
316316
auto ishape_tensor = ishape_layer->getOutput(0); // input shape
317317

318-
auto startIdx = 0;
318+
int startIdx = 0;
319319
auto startIdxIVal = args[2].IValue();
320320
if (!startIdxIVal->isNone()) {
321-
startIdx = startIdxIVal->toInt();
321+
startIdx = std::min((int64_t)std::numeric_limits<int32_t>::max(), startIdxIVal->toInt());
322322
}
323323
// Handle case when given tensor index is negative
324324
if (maxDim > 0) { // only for static shape
325325
startIdx = (startIdx < 0) ? (maxDim + startIdx) : startIdx;
326326
}
327327

328328
// Bound the end index to input tensor dimensions at specified axis
329-
auto endIdx = maxDim; // -1 for dynamic shape
329+
int endIdx = maxDim; // -1 for dynamic shape
330330
auto endIdxIVal = args[3].IValue();
331331
if (!endIdxIVal->isNone()) {
332-
endIdx = maxDim == -1 ? endIdxIVal->toInt() : std::min(endIdxIVal->toInt(), maxDim);
332+
int truncate_value = std::min((int64_t)std::numeric_limits<int32_t>::max(), endIdxIVal->toInt());
333+
endIdx = maxDim == -1 ? truncate_value : std::min(truncate_value, maxDim);
333334
}
334335
if (maxDim > 0) {
335336
endIdx = (endIdx < 0) ? (maxDim + endIdx) : endIdx;
336337
}
337-
auto step = args[4].unwrapToInt();
338+
int step = args[4].unwrapToInt();
338339

339340
// update start, end, stride for static shape
340-
auto nbdims = in->getDimensions().nbDims;
341+
int nbdims = in->getDimensions().nbDims;
341342
nvinfer1::Dims start_, size_, stride_;
342343
start_.nbDims = nbdims;
343344
size_.nbDims = nbdims;

tests/core/conversion/converters/test_select.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ TEST(Converters, ATenSliceDynamicBatchLargeEndConvertsCorrectly) {
400400
%2 : None = prim::Constant()
401401
%dim : int = prim::Constant[value=0]()
402402
%start : int = prim::Constant[value=1]()
403-
%end : int = prim::Constant[value=99999]()
403+
%end : int = prim::Constant[value=9223372036854775807]()
404404
%step : int = prim::Constant[value=2]()
405405
%9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step)
406406
return (%9))IR";

0 commit comments

Comments
 (0)