Skip to content

Commit de42bf0

Browse files
committed
fix: fix aten::slice when index is LONG_MAX
Signed-off-by: inocsin <[email protected]>
1 parent 6d6a347 commit de42bf0

File tree

3 files changed

+16
-14
lines changed

3 files changed

+16
-14
lines changed

core/conversion/converters/converter_util.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,8 @@ nvinfer1::ITensor* clamp(
218218
}
219219

220220
// clamp x to [0, input_dim]
221-
nvinfer1::ITensor* clamp_to_input_dim(ConversionCtx* ctx, nvinfer1::ITensor* x, nvinfer1::ITensor* input_dim) {
222-
auto nbdims = input_dim->getDimensions().d[0];
221+
nvinfer1::ITensor* clamp_to_input_dim(ConversionCtx* ctx, nvinfer1::ITensor* x, nvinfer1::ITensor* input_dim, int nbdims) {
222+
// auto nbdims = input_dim->getDimensions().d[0];
223223
auto zero = torch::zeros({nbdims}).to(torch::kI32);
224224
auto zero_itensor = tensor_to_const(ctx, zero);
225225
auto one = torch::ones({nbdims}).to(torch::kI32);
@@ -243,8 +243,7 @@ nvinfer1::ITensor* clamp_to_input_dim(ConversionCtx* ctx, nvinfer1::ITensor* x,
243243
}
244244

245245
// return indices < 0 ? inputDims + indices : indices
246-
nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, nvinfer1::ITensor* input_dim, nvinfer1::ITensor* indices) {
247-
auto nbdims = input_dim->getDimensions().d[0];
246+
nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, nvinfer1::ITensor* input_dim, nvinfer1::ITensor* indices, int nbdims) {
248247
auto zero = torch::zeros({nbdims}).to(torch::kI32);
249248
auto neg = -torch::ones({nbdims}).to(torch::kI32);
250249
auto zero_itensor = tensor_to_const(ctx, zero);
@@ -270,11 +269,12 @@ std::vector<nvinfer1::ITensor*> update_start_and_end(
270269
ConversionCtx* ctx,
271270
nvinfer1::ITensor* in_shape,
272271
nvinfer1::ITensor* in_start,
273-
nvinfer1::ITensor* in_end) {
274-
auto start = bump_if_negtive(ctx, in_shape, in_start);
275-
auto out_start = clamp_to_input_dim(ctx, start, in_shape);
276-
auto end = bump_if_negtive(ctx, in_shape, in_end);
277-
auto out_end = clamp_to_input_dim(ctx, end, in_shape);
272+
nvinfer1::ITensor* in_end,
273+
int nbdims) {
274+
auto start = bump_if_negtive(ctx, in_shape, in_start, nbdims);
275+
auto out_start = clamp_to_input_dim(ctx, start, in_shape, nbdims);
276+
auto end = bump_if_negtive(ctx, in_shape, in_end, nbdims);
277+
auto out_end = clamp_to_input_dim(ctx, end, in_shape, nbdims);
278278
std::vector<nvinfer1::ITensor*> outputs;
279279
outputs.push_back(out_start);
280280
outputs.push_back(out_end);

core/conversion/converters/converter_util.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ std::vector<nvinfer1::ITensor*> update_start_and_end(
6363
ConversionCtx* ctx,
6464
nvinfer1::ITensor* in_shape,
6565
nvinfer1::ITensor* in_start,
66-
nvinfer1::ITensor* in_end);
66+
nvinfer1::ITensor* in_end,
67+
int nbdims);
6768

6869
nvinfer1::ITensor* calculate_output_size(
6970
ConversionCtx* ctx,

core/conversion/converters/impl/select.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,8 @@ auto select_registrations TORCHTRT_UNUSED =
318318
int startIdx = 0;
319319
auto startIdxIVal = args[2].IValue();
320320
if (!startIdxIVal->isNone()) {
321-
startIdx = std::min((int64_t)std::numeric_limits<int32_t>::max(), startIdxIVal->toInt());
321+
startIdx = startIdxIVal->toInt() > std::numeric_limits<int32_t>::max() ? maxDim : startIdxIVal->toInt();
322+
startIdx = maxDim == -1 ? startIdx : std::min(startIdx, maxDim);
322323
}
323324
// Handle case when given tensor index is negative
324325
if (maxDim > 0) { // only for static shape
@@ -329,7 +330,7 @@ auto select_registrations TORCHTRT_UNUSED =
329330
int endIdx = maxDim; // -1 for dynamic shape
330331
auto endIdxIVal = args[3].IValue();
331332
if (!endIdxIVal->isNone()) {
332-
int truncate_value = std::min((int64_t)std::numeric_limits<int32_t>::max(), endIdxIVal->toInt());
333+
int truncate_value = endIdxIVal->toInt() > std::numeric_limits<int32_t>::max() ? maxDim : endIdxIVal->toInt();
333334
endIdx = maxDim == -1 ? truncate_value : std::min(truncate_value, maxDim);
334335
}
335336
if (maxDim > 0) {
@@ -373,7 +374,7 @@ auto select_registrations TORCHTRT_UNUSED =
373374
at::Tensor end_tensor = torch::zeros({nbdims}).to(torch::kI32);
374375
for (int i = 0; i < nbdims; i++) {
375376
if (i == axis) {
376-
end_tensor[i] = endIdxIVal->isNone() ? -1 : endIdx - 1;
377+
end_tensor[i] = endIdx == -1 ? -1 : endIdx - 1;
377378
} else {
378379
end_tensor[i] = input_dim.d[i] == -1 ? -1 : input_dim.d[i] - 1;
379380
}
@@ -383,7 +384,7 @@ auto select_registrations TORCHTRT_UNUSED =
383384
// update start and end
384385
nvinfer1::ITensor* out_start;
385386
nvinfer1::ITensor* out_end;
386-
auto start_end = update_start_and_end(ctx, ishape_tensor, start_itensor, end_itensor);
387+
auto start_end = update_start_and_end(ctx, ishape_tensor, start_itensor, end_itensor, nbdims);
387388
out_start = start_end[0];
388389
out_end = start_end[1];
389390

0 commit comments

Comments
 (0)