Skip to content

Commit 8716c9b

Browse files
authored
Bug fixes for ArgMin/ArgMax, GlobalMaxPool and Squeeze (#331)
1 parent 2066f53 commit 8716c9b

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

builtin_op_importers.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,7 @@ DEFINE_BUILTIN_OP_IMPORTER(GlobalMaxPool)
889889
static_cast<nvinfer1::Dims>(nvinfer1::Dims2{dims.d[2], dims.d[3]}) :
890890
static_cast<nvinfer1::Dims>(nvinfer1::Dims3{dims.d[2], dims.d[3], dims.d[4]});
891891
ASSERT(!isDynamic(kernelSize) && "Cannot run GlobalMaxPool on an input with dynamic spatial dimensions!", ErrorCode::kUNSUPPORTED_NODE);
892-
RETURN_FIRST_OUTPUT(ctx->network()->addPoolingNd(tensor, nvinfer1::PoolingType::kAVERAGE, kernelSize));
892+
RETURN_FIRST_OUTPUT(ctx->network()->addPoolingNd(tensor, nvinfer1::PoolingType::kMAX, kernelSize));
893893
}
894894

895895
DEFINE_BUILTIN_OP_IMPORTER(HardSigmoid)

onnx2trt_utils.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,26 +229,45 @@ Status applyLegacyBinaryOpBroadcasting(IImporterContext* ctx,
229229
NodeImportResult argMinMaxHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
230230
std::vector<TensorOrWeights>& inputs, nvinfer1::TopKOperation op)
231231
{
232-
nvinfer1::ITensor& tensor = convertToTensor(inputs.at(0), ctx);
233-
ASSERT(tensor.getType() != nvinfer1::DataType::kINT32, ErrorCode::kUNSUPPORTED_NODE);
232+
nvinfer1::ITensor* tensorPtr = &convertToTensor(inputs.at(0), ctx);
233+
ASSERT(tensorPtr->getType() != nvinfer1::DataType::kINT32, ErrorCode::kUNSUPPORTED_NODE);
234+
235+
// Support 1D argMin/argMax
236+
bool needToExpandDims = (tensorPtr->getDimensions().nbDims == 1);
237+
if (needToExpandDims)
238+
{
239+
// Expand dims from 1D to 2D
240+
std::vector<int> axes{1};
241+
tensorPtr = unsqueezeTensor(ctx, *tensorPtr, axes);
242+
ASSERT(tensorPtr, ErrorCode::kUNSUPPORTED_NODE);
243+
}
234244
// Get attributes.
235245
OnnxAttrs attrs(node);
236246
int keepdims = attrs.get("keepdims", 1);
237247
int axis = attrs.get("axis", 0);
238248

239249
// Insert a TopK layer with k set to 1.
240-
int nbDims = tensor.getDimensions().nbDims;
250+
int nbDims = tensorPtr->getDimensions().nbDims;
241251
TRT_CHECK(convert_axis(axis, nbDims));
242252

243253
uint32_t axisMask = 1 << axis;
244-
nvinfer1::ITopKLayer* layer = ctx->network()->addTopK(tensor, op, 1, axisMask);
254+
nvinfer1::ITopKLayer* layer = ctx->network()->addTopK(*tensorPtr, op, 1, axisMask);
245255
ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE);
246256
// We don't care about the TopK values, just the indices.
247257
nvinfer1::ITensor* indices = layer->getOutput(1);
248258
indices->setType(nvinfer1::DataType::kINT32);
259+
260+
// Squeeze back to 1D if applicable
261+
if (needToExpandDims)
262+
{
263+
std::vector<int> axes{1};
264+
indices = squeezeTensor(ctx, *indices, axes);
265+
ASSERT(indices, ErrorCode::kUNSUPPORTED_NODE);
266+
}
267+
268+
// The default behavior of the TopK layer is to keepdims.
249269
if (keepdims)
250270
{
251-
// The default behavior of the TopK layer is to keepdims.
252271
return {{indices}};
253272
}
254273
else
@@ -1177,9 +1196,11 @@ nvinfer1::ITensor* squeezeStaticTensor(IImporterContext* ctx, nvinfer1::ITensor&
11771196
std::set<int> axesSet(axes.begin(), axes.end());
11781197
std::vector<int> shape{dims.d, dims.d + dims.nbDims};
11791198

1199+
int axisCount = 0;
11801200
for (const auto& axis : axesSet)
11811201
{
1182-
shape.erase(shape.begin() + axis);
1202+
shape.erase(shape.begin() + axis - axisCount);
1203+
axisCount++;
11831204
}
11841205

11851206
nvinfer1::Dims newShape{dims.nbDims - static_cast<int>(axesSet.size())};

0 commit comments

Comments (0)