@@ -229,26 +229,45 @@ Status applyLegacyBinaryOpBroadcasting(IImporterContext* ctx,
229229NodeImportResult argMinMaxHelper (IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
230230 std::vector<TensorOrWeights>& inputs, nvinfer1::TopKOperation op)
231231{
232- nvinfer1::ITensor& tensor = convertToTensor (inputs.at (0 ), ctx);
233- ASSERT (tensor.getType () != nvinfer1::DataType::kINT32 , ErrorCode::kUNSUPPORTED_NODE );
232+ nvinfer1::ITensor* tensorPtr = &convertToTensor (inputs.at (0 ), ctx);
233+ ASSERT (tensorPtr->getType () != nvinfer1::DataType::kINT32 , ErrorCode::kUNSUPPORTED_NODE );
234+
235+ // Support 1D argMin/argMax
236+ bool needToExpandDims = (tensorPtr->getDimensions ().nbDims == 1 );
237+ if (needToExpandDims)
238+ {
239+ // Expand dims from 1D to 2D
240+ std::vector<int > axes{1 };
241+ tensorPtr = unsqueezeTensor (ctx, *tensorPtr, axes);
242+ ASSERT (tensorPtr, ErrorCode::kUNSUPPORTED_NODE );
243+ }
234244 // Get attributes.
235245 OnnxAttrs attrs (node);
236246 int keepdims = attrs.get (" keepdims" , 1 );
237247 int axis = attrs.get (" axis" , 0 );
238248
239249 // Insert a TopK layer with k set to 1.
240- int nbDims = tensor. getDimensions ().nbDims ;
250+ int nbDims = tensorPtr-> getDimensions ().nbDims ;
241251 TRT_CHECK (convert_axis (axis, nbDims));
242252
243253 uint32_t axisMask = 1 << axis;
244- nvinfer1::ITopKLayer* layer = ctx->network ()->addTopK (tensor , op, 1 , axisMask);
254+ nvinfer1::ITopKLayer* layer = ctx->network ()->addTopK (*tensorPtr , op, 1 , axisMask);
245255 ASSERT (layer, ErrorCode::kUNSUPPORTED_NODE );
246256 // We don't care about the TopK values, just the indices.
247257 nvinfer1::ITensor* indices = layer->getOutput (1 );
248258 indices->setType (nvinfer1::DataType::kINT32 );
259+
260+ // Squeeze back to 1D if applicable
261+ if (needToExpandDims)
262+ {
263+ std::vector<int > axes{1 };
264+ indices = squeezeTensor (ctx, *indices, axes);
265+ ASSERT (indices, ErrorCode::kUNSUPPORTED_NODE );
266+ }
267+
268+ // The default behavior of the TopK layer is to keepdims.
249269 if (keepdims)
250270 {
251- // The default behavior of the TopK layer is to keepdims.
252271 return {{indices}};
253272 }
254273 else
@@ -1177,9 +1196,11 @@ nvinfer1::ITensor* squeezeStaticTensor(IImporterContext* ctx, nvinfer1::ITensor&
11771196 std::set<int > axesSet (axes.begin (), axes.end ());
11781197 std::vector<int > shape{dims.d , dims.d + dims.nbDims };
11791198
1199+ int axisCount = 0 ;
11801200 for (const auto & axis : axesSet)
11811201 {
1182- shape.erase (shape.begin () + axis);
1202+ shape.erase (shape.begin () + axis - axisCount);
1203+ axisCount++;
11831204 }
11841205
11851206 nvinfer1::Dims newShape{dims.nbDims - static_cast <int >(axesSet.size ())};
0 commit comments