Skip to content

Commit b7c12a8

Browse files
authored
Updating slice size calculation (#322)
1 parent bd8b630 commit b7c12a8

File tree

1 file changed

+15
-36
lines changed

1 file changed

+15
-36
lines changed

builtin_op_importers.cpp

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,12 +1594,12 @@ NodeImportResult staticInputSliceHelper(IImporterContext* ctx, nvinfer1::ITensor
15941594
std::tie(start, end, stride) = sliceBounds.at(axis);
15951595

15961596
end = std::min(handleNegativeIndex(end), static_cast<int64_t>(shape.d[axis]));
1597-
15981597
starts.d[axis] = handleNegativeIndex(start);
15991598
strides.d[axis] = handleNegativeIndex(stride);
1600-
int64_t sliceSize = (end - starts.d[axis]) / strides.d[axis];
1601-
// Add 1 to slice size since int division floors the result.
1602-
sizes.d[axis] = (end - starts.d[axis]) % strides.d[axis] == 0 ? sliceSize : sliceSize + 1;
1599+
// size = ceil((end - start) / stride). All values have been converted, so these casts to float should be OK.
1600+
int64_t sliceSize = static_cast<int64_t>(std::ceil((end - starts.d[axis]) / static_cast<float>(strides.d[axis])));
1601+
ASSERT(sliceSize > 0 && "TensorRT does not support 0 sized slices!", ErrorCode::kUNSUPPORTED_NODE);
1602+
sizes.d[axis] = sliceSize;
16031603
}
16041604
else
16051605
{
@@ -1657,22 +1657,17 @@ NodeImportResult dynamicInputSliceHelper(IImporterContext* ctx, nvinfer1::ITenso
16571657
endIndex = getAxisLength(ctx, &tensor, i, indexShape);
16581658
strideIndices.emplace_back(handleNegativeIndex(1));
16591659
}
1660-
// size[i] = min(end[i], getAxisLength(ctx, &tensor, i, indexShape)) - start[i]) / stride[i]
1661-
sizeIndices.emplace_back(
1662-
ctx->network()->addElementWise(
1663-
*ctx->network()->addElementWise(
1664-
*ctx->network()->addElementWise(
1665-
*endIndex,
1666-
*getAxisLength(ctx, &tensor, i, indexShape),
1667-
nvinfer1::ElementWiseOperation::kMIN
1668-
)->getOutput(0),
1669-
*startIndices.back(),
1670-
nvinfer1::ElementWiseOperation::kSUB
1671-
)->getOutput(0),
1672-
*strideIndices.back(),
1673-
nvinfer1::ElementWiseOperation::kDIV
1674-
)->getOutput(0)
1675-
);
1660+
auto* zeroTensor = addConstantScalar<int32_t>(ctx, 0, ::ONNX_NAMESPACE::TensorProto::INT32)->getOutput(0);
1661+
// Adapt the size caluclation into TensorRT supported layers.
1662+
// size[i] = -(floor_div(-(min(end[i], getAxisLength(ctx, &tensor, i, indexShape)) - start[i]), stride[i]))
1663+
auto* end = ctx->network()->addElementWise(*endIndex, *getAxisLength(ctx,&tensor,i, indexShape), nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
1664+
auto* diff = ctx->network()->addElementWise(*end, *startIndices.back(), nvinfer1::ElementWiseOperation::kSUB)->getOutput(0);
1665+
// Negate the diff
1666+
diff = ctx->network()->addElementWise(*zeroTensor, *diff, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0);
1667+
auto* size = ctx->network()->addElementWise(*diff, *strideIndices.back(), nvinfer1::ElementWiseOperation::kFLOOR_DIV)->getOutput(0);
1668+
// Negate the size
1669+
size = ctx->network()->addElementWise(*zeroTensor, *size, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0);
1670+
sizeIndices.emplace_back(size);
16761671
}
16771672

16781673
nvinfer1::ITensor* startsTensor = ctx->network()->addConcatenation(startIndices.data(), startIndices.size())->getOutput(0);
@@ -1690,7 +1685,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Slice)
16901685
{
16911686
// If opset version >= 10 slice paramerters are weights instead of attributes
16921687
nvinfer1::ITensor& tensor = convertToTensor(inputs.at(0), ctx);
1693-
nvinfer1::Dims dims = tensor.getDimensions();
16941688
std::vector<int64_t> starts;
16951689
std::vector<int64_t> ends;
16961690
std::vector<int64_t> axes;
@@ -1741,21 +1735,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Slice)
17411735

17421736
for (size_t i = 0; i < axes.size(); ++i)
17431737
{
1744-
// Do a sanity check here that the combination of starts, ends, and steps does not cause a size 0 for non-dynamic dimensions.
1745-
auto axis = axes.at(i);
1746-
const auto handleNegativeIndex = [&axis, &dims](int64_t index) -> int64_t
1747-
{
1748-
return (index < 0) ? (dims.d[axis] + index) : index;
1749-
};
1750-
1751-
if (dims.d[axis] != -1)
1752-
{
1753-
int64_t startsVal = handleNegativeIndex(starts.at(i));
1754-
int64_t endsVal = handleNegativeIndex(ends.at(i));
1755-
ASSERT((std::min(static_cast<int64_t>(dims.d[axis]), endsVal) - startsVal) / steps.at(i) != 0
1756-
&& "TensorRT does not support size 0 slices!", ErrorCode::kUNSUPPORTED_NODE);
1757-
}
1758-
17591738
sliceBounds[axes.at(i)] = std::make_tuple(starts.at(i), ends.at(i), steps.at(i));
17601739
}
17611740

0 commit comments

Comments
 (0)