@@ -307,37 +307,38 @@ auto select_registrations TORCHTRT_UNUSED =
307
307
{" aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)" ,
308
308
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
309
309
auto in = args[0 ].ITensorOrFreeze (ctx);
310
- auto axis = args[1 ].unwrapToInt ();
311
- auto maxDim = static_cast <int64_t >(in->getDimensions ().d [axis]);
310
+ int axis = args[1 ].unwrapToInt ();
311
+ int maxDim = static_cast <int32_t >(in->getDimensions ().d [axis]);
312
312
bool dynamic_shape = is_dynamic_shape (in);
313
313
auto input_dim = in->getDimensions ();
314
314
// add Shape Tensor
315
315
auto ishape_layer = ctx->net ->addShape (*in);
316
316
auto ishape_tensor = ishape_layer->getOutput (0 ); // input shape
317
317
318
- auto startIdx = 0 ;
318
+ int startIdx = 0 ;
319
319
auto startIdxIVal = args[2 ].IValue ();
320
320
if (!startIdxIVal->isNone ()) {
321
- startIdx = startIdxIVal->toInt ();
321
+ startIdx = std::min (( int64_t )std::numeric_limits< int32_t >:: max (), startIdxIVal->toInt () );
322
322
}
323
323
// Handle case when given tensor index is negative
324
324
if (maxDim > 0 ) { // only for static shape
325
325
startIdx = (startIdx < 0 ) ? (maxDim + startIdx) : startIdx;
326
326
}
327
327
328
328
// Bound the end index to input tensor dimensions at specified axis
329
- auto endIdx = maxDim; // -1 for dynamic shape
329
+ int endIdx = maxDim; // -1 for dynamic shape
330
330
auto endIdxIVal = args[3 ].IValue ();
331
331
if (!endIdxIVal->isNone ()) {
332
- endIdx = maxDim == -1 ? endIdxIVal->toInt () : std::min (endIdxIVal->toInt (), maxDim);
332
+ int truncate_value = std::min ((int64_t )std::numeric_limits<int32_t >::max (), endIdxIVal->toInt ());
333
+ endIdx = maxDim == -1 ? truncate_value : std::min (truncate_value, maxDim);
333
334
}
334
335
if (maxDim > 0 ) {
335
336
endIdx = (endIdx < 0 ) ? (maxDim + endIdx) : endIdx;
336
337
}
337
- auto step = args[4 ].unwrapToInt ();
338
+ int step = args[4 ].unwrapToInt ();
338
339
339
340
// update start, end, stride for static shape
340
- auto nbdims = in->getDimensions ().nbDims ;
341
+ int nbdims = in->getDimensions ().nbDims ;
341
342
nvinfer1::Dims start_, size_, stride_;
342
343
start_.nbDims = nbdims;
343
344
size_.nbDims = nbdims;
0 commit comments