@@ -219,6 +219,24 @@ nvinfer1::ITensor* clamp(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1
219
219
return min_itensor;
220
220
}
221
221
222
+ // clamp x to [0, input_dim]
223
+ nvinfer1::ITensor* clamp_to_input_dim (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* x,
224
+ nvinfer1::ITensor* input_dim) {
225
+ auto nbdims = input_dim->getDimensions ().d [0 ];
226
+ auto zero = torch::zeros ({nbdims}).to (torch::kI32 );
227
+ auto zero_itensor = toITensor (ctx, n, &zero);
228
+ auto one = torch::ones ({nbdims}).to (torch::kI32 );
229
+ auto one_itensor = toITensor (ctx, n, &one);
230
+ auto upper_bound_layer = ctx->net ->addElementWise (*input_dim, *one_itensor, nvinfer1::ElementWiseOperation::kSUB );
231
+ auto upper_bound = upper_bound_layer->getOutput (0 );
232
+ auto max_layer = ctx->net ->addElementWise (*x, *zero_itensor, nvinfer1::ElementWiseOperation::kMAX );
233
+ auto max_itensor = max_layer->getOutput (0 );
234
+ auto min_layer = ctx->net ->addElementWise (*max_itensor, *upper_bound, nvinfer1::ElementWiseOperation::kMIN );
235
+ auto min_itensor = min_layer->getOutput (0 );
236
+ return min_itensor;
237
+ }
238
+
239
+
222
240
// return indices < 0 ? inputDims + indices : indices
223
241
nvinfer1::ITensor* bump_if_negtive (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_dim,
224
242
nvinfer1::ITensor* indices) {
@@ -238,8 +256,10 @@ nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, const torch::jit::Node* n
238
256
void update_start_and_end (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in_shape,
239
257
nvinfer1::ITensor* in_start, nvinfer1::ITensor* in_end,
240
258
nvinfer1::ITensor** out_start, nvinfer1::ITensor** out_end) {
241
- *out_start = bump_if_negtive (ctx, n, in_shape, in_start);
242
- *out_end = bump_if_negtive (ctx, n, in_shape, in_end);
259
+ auto start = bump_if_negtive (ctx, n, in_shape, in_start);
260
+ *out_start = clamp_to_input_dim (ctx, n, start, in_shape);
261
+ auto end = bump_if_negtive (ctx, n, in_shape, in_end);
262
+ *out_end = clamp_to_input_dim (ctx, n, end, in_shape);
243
263
}
244
264
245
265
bool is_dynamic_shape (nvinfer1::ITensor* tensor) {
0 commit comments