@@ -32,7 +32,9 @@ nvinfer1::ITensor* index_layer(
3232c10::IValue dynamic_size_layer (ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
3333 LOG_DEBUG (" Using dynamic version of aten::size evaluator" );
3434 auto in = args.at (n->input (0 )).ITensorOrFreeze (ctx);
35- LOG_DEBUG (" Input dimensions: " << in->getDimensions ());
35+ auto input_dims = in->getDimensions ();
36+ LOG_DEBUG (" Input dimensions: " << input_dims);
37+
3638 auto shape_layer = ctx->net ->addShape (*in);
3739 TORCHTRT_CHECK (shape_layer, " Unable to create shape layer from node: " << *n);
3840 auto shape_1d_tensor = shape_layer->getOutput (0 );
@@ -44,15 +46,31 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kw
4446 dim = dim < 0 ? dim + maxDim : dim;
4547 LOG_DEBUG (" Dimension to select: " << dim);
4648 shape_1d_tensor = index_layer (ctx, n, shape_1d_tensor, dim);
47- }
49+ LOG_DEBUG ( " Output tensor shape: " << shape_1d_tensor-> getDimensions ());
4850
49- LOG_DEBUG (" Output tensor shape: " << shape_1d_tensor->getDimensions ());
51+ auto tensor_holder = TensorContainer ();
52+ tensor_holder.hold_tensor (shape_1d_tensor);
53+ auto shape_1d_ivalue = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
5054
51- auto tensor_holder = TensorContainer ();
52- tensor_holder.hold_tensor (shape_1d_tensor);
53- auto shape_1d_ivalue = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
55+ return shape_1d_ivalue;
5456
55- return shape_1d_ivalue;
57+ } else {
58+ auto input_size = c10::impl::GenericList (c10::AnyType::get ());
59+ // Only express the dynamic dimension with a shape layer output.
60+ // The static dimensions are preserved in the input size.
61+ for (int32_t i = 0 ; i < input_dims.nbDims ; i++) {
62+ if (input_dims.d [i] == -1 ) {
63+ auto dynamic_dim_tensor = index_layer (ctx, n, shape_1d_tensor, i);
64+ auto dynamic_dim_holder = TensorContainer ();
65+ dynamic_dim_holder.hold_tensor (dynamic_dim_tensor);
66+ auto dynamic_dim_ivalue = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(dynamic_dim_holder)));
67+ input_size.emplace_back (std::move (dynamic_dim_ivalue));
68+ } else {
69+ input_size.emplace_back (input_dims.d [i]);
70+ }
71+ }
72+ return c10::IValue (input_size);
73+ }
5674}
5775
5876int64_t normalizeIndex (int64_t idx, int64_t list_size) {
0 commit comments