Commit 88717f2

fix: clamp start and end to [0, input_dim] in aten::slice
Signed-off-by: inocsin <[email protected]>
1 parent b229c37 commit 88717f2

2 files changed (+51, −2 lines)

core/conversion/converters/converter_util.cpp

Lines changed: 22 additions & 2 deletions

@@ -219,6 +219,24 @@ nvinfer1::ITensor* clamp(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1
   return min_itensor;
 }
 
+// clamp x to [0, input_dim]
+nvinfer1::ITensor* clamp_to_input_dim(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* x,
+                                      nvinfer1::ITensor* input_dim) {
+  auto nbdims = input_dim->getDimensions().d[0];
+  auto zero = torch::zeros({nbdims}).to(torch::kI32);
+  auto zero_itensor = toITensor(ctx, n, &zero);
+  auto one = torch::ones({nbdims}).to(torch::kI32);
+  auto one_itensor = toITensor(ctx, n, &one);
+  auto upper_bound_layer = ctx->net->addElementWise(*input_dim, *one_itensor, nvinfer1::ElementWiseOperation::kSUB);
+  auto upper_bound = upper_bound_layer->getOutput(0);
+  auto max_layer = ctx->net->addElementWise(*x, *zero_itensor, nvinfer1::ElementWiseOperation::kMAX);
+  auto max_itensor = max_layer->getOutput(0);
+  auto min_layer = ctx->net->addElementWise(*max_itensor, *upper_bound, nvinfer1::ElementWiseOperation::kMIN);
+  auto min_itensor = min_layer->getOutput(0);
+  return min_itensor;
+}
+
 // return indices < 0 ? inputDims + indices : indices
 nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_dim,
                                    nvinfer1::ITensor* indices) {
@@ -238,8 +256,10 @@ nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, const torch::jit::Node* n
 void update_start_and_end(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in_shape,
                           nvinfer1::ITensor* in_start, nvinfer1::ITensor* in_end,
                           nvinfer1::ITensor** out_start, nvinfer1::ITensor** out_end) {
-  *out_start = bump_if_negtive(ctx, n, in_shape, in_start);
-  *out_end = bump_if_negtive(ctx, n, in_shape, in_end);
+  auto start = bump_if_negtive(ctx, n, in_shape, in_start);
+  *out_start = clamp_to_input_dim(ctx, n, start, in_shape);
+  auto end = bump_if_negtive(ctx, n, in_shape, in_end);
+  *out_end = clamp_to_input_dim(ctx, n, end, in_shape);
 }
 
 bool is_dynamic_shape(nvinfer1::ITensor* tensor) {
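
For intuition, here is a minimal standalone sketch (not part of the commit) of the arithmetic that bump_if_negtive and the new clamp_to_input_dim express with TensorRT elementwise layers. It assumes plain scalar indices for a single dimension, whereas the converters operate on ITensors of length nbdims; the function names below are illustrative only.

#include <algorithm>
#include <cstdio>

// Illustrative scalar version of bump_if_negtive:
// negative indices are interpreted relative to the end of the dimension.
int bump_if_negative(int input_dim, int index) {
  return index < 0 ? input_dim + index : index;
}

// Illustrative scalar version of clamp_to_input_dim:
// min(max(index, 0), input_dim - 1), matching the kMAX/kMIN layer pair
// with the kSUB-computed upper bound.
int clamp_to_input_dim(int input_dim, int index) {
  return std::min(std::max(index, 0), input_dim - 1);
}

int main() {
  const int input_dim = 16;  // e.g. the batch dimension in the new test
  int start = clamp_to_input_dim(input_dim, bump_if_negative(input_dim, 1));    // 1
  int end = clamp_to_input_dim(input_dim, bump_if_negative(input_dim, 99999));  // 15
  std::printf("start=%d end=%d\n", start, end);
  return 0;
}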

tests/core/conversion/converters/test_select.cpp

Lines changed: 29 additions & 0 deletions

@@ -394,6 +394,35 @@ TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenSliceDynamicBatchLargeEndConvertsCorrectly) {
+  const auto graph = R"IR(
+        graph(%x.1 : Tensor):
+              %2 : None = prim::Constant()
+              %dim : int = prim::Constant[value=0]()
+              %start : int = prim::Constant[value=1]()
+              %end : int = prim::Constant[value=99999]()
+              %step : int = prim::Constant[value=2]()
+              %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step)
+              return (%9))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {16, 32}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  // dynamic shape in batch
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 TEST(Converters, ATenSliceDynamicNegStartBatchConvertsCorrectly) {
   const auto graph = R"IR(
        graph(%x.1 : Tensor):
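
For reference, a minimal libtorch-only sketch (not part of this commit) of the eager-mode behavior the new regression test compares the TensorRT engine against: an out-of-range end in slice is clamped to the dimension size, so a {16, 32} input sliced along dim 0 with start=1, end=99999, step=2 yields 8 rows.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Same shape and slice parameters as ATenSliceDynamicBatchLargeEndConvertsCorrectly,
  // but on CPU and in eager mode rather than through the converter.
  auto x = torch::randint(1, 10, {16, 32});
  auto y = x.slice(/*dim=*/0, /*start=*/1, /*end=*/99999, /*step=*/2);
  std::cout << y.sizes() << std::endl;  // prints [8, 32]
  return 0;
}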
