
Commit 829cf8a

[tosa] Implement Argmax support (#485)
Signed-off-by: Suraj Sudhir <[email protected]>
1 parent d13bb0e commit 829cf8a
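
This commit adds a TorchToTosa lowering pattern for torch.aten.argmax. For context, a hypothetical instance of the op being lowered (value names and shapes are illustrative, not taken from the commit):

  %dim = torch.constant.int 1
  %false = torch.constant.bool false
  %indices = torch.aten.argmax %input, %dim, %false
      : !torch.vtensor<[4,8],f32>, !torch.int, !torch.bool -> !torch.vtensor<[4],si64>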

File tree: 2 files changed, +96 -1 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 94 additions & 1 deletion

@@ -12,6 +12,8 @@
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
 
 #include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
@@ -404,6 +406,93 @@ class ConvertAtenAllDimsReductionOp
   }
 };
 
+template <>
+LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
+    AtenArgmaxOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  Value self = adaptor.self();
+  auto selfTy = self.getType().template dyn_cast<RankedTensorType>();
+
+  if (!selfTy)
+    return op.emitError("Only ranked tensor types supported in TOSA argmax");
+
+  int64_t reduceDim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim))) {
+    // NoneType indicates reduce on all dims.
+    reduceDim = -1;
+  }
+
+  bool keepDim = false;
+  if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim)))
+    return rewriter.notifyMatchFailure(
+        op, "non-const keepdim parameter unsupported");
+
+  auto resultTy = getTypeConverter()
+                      ->convertType(op.getResult().getType())
+                      .cast<RankedTensorType>();
+  auto outputETy = resultTy.getElementType();
+
+  // Create a single instance of tosa.argmax.
+  // Reducing multiple dims requires a chained construct.
+  auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
+    auto inputTy = input.getType().cast<RankedTensorType>();
+    auto inputShape = inputTy.getShape();
+    SmallVector<int64_t> outputShapeArr = {};
+    int32_t i = 0;
+
+    for (auto &dim : inputShape) {
+      if (i++ != reduceDim) {
+        outputShapeArr.push_back(dim);
+      } else {
+        if (keepDim)
+          outputShapeArr.push_back(1);
+      }
+    }
+
+    // TOSA argmax output is i32, while the Torch backend mandates i64.
+    auto outputReduceTy = RankedTensorType::get(
+        ArrayRef<int64_t>(outputShapeArr), rewriter.getI32Type());
+    auto reduceDimAttr =
+        rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim);
+    return rewriter
+        .create<tosa::ArgMaxOp>(op->getLoc(),
+                                getTypeConverter()->convertType(outputReduceTy),
+                                input, reduceDimAttr)
+        .getResult();
+  };
+
+  // Convert the final index to i64 for backend finalization. However, i64
+  // is not a defined type for tosa.cast, so arith.extsi is used instead.
+  auto castToInt64 = [&](Value result) -> LogicalResult {
+    auto resTy = result.getType().dyn_cast<ShapedType>();
+    if (!resTy)
+      return op.emitError("Argmax: Result is not a shaped type");
+
+    auto resShape = resTy.getShape();
+    auto outTy = RankedTensorType::get(resShape, outputETy);
+
+    rewriter.replaceOpWithNewOp<arith::ExtSIOp>(
+        op, getTypeConverter()->convertType(outTy), result);
+
+    return success();
+  };
+
+  if (reduceDim == -1) { // reducing on all dims
+    Value input = self;
+    for (int dim = 0; dim < selfTy.getRank(); dim++) {
+      // Progressively reduce dim 0; each step folds away the leading dim.
+      input = buildArgmax(0, input);
+    }
+    return castToInt64(input);
+  } else {
+    return castToInt64(buildArgmax(reduceDim, self));
+  }
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -415,13 +504,16 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<tosa::TosaDialect>();
+    registry.insert<tensor::TensorDialect>();
+    registry.insert<arith::ArithmeticDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
-    target.addLegalDialect<tosa::TosaDialect>();
+    target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
+                           arith::ArithmeticDialect>();
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -491,6 +583,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenReluOp);
     INSERT_ATENOP_PATTERN(AtenMulTensorOp);
     INSERT_ATENOP_PATTERN(AtenDivTensorOp);
+    INSERT_ATENOP_PATTERN(AtenArgmaxOp);
 #undef INSERT_ATENOP_PATTERN
 
     if (failed(applyPartialConversion(getOperation(), target,
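
A sketch of the TOSA-level IR this pattern should emit, under illustrative assumptions (generic op syntax; value names and shapes are hypothetical). For argmax along dim 1 with keepdim=false, one tosa.argmax is created and its i32 result is sign-extended to the i64 the backend contract expects:

  %0 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<4x8xf32>) -> tensor<4xi32>
  %1 = arith.extsi %0 : tensor<4xi32> to tensor<4xi64>

For the all-dims case (dim given as None), the pattern chains one tosa.argmax per input rank, each reducing the leading dimension of the previous result, before the final sign-extension:

  %0 = "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<4x8xf32>) -> tensor<8xi32>
  %1 = "tosa.argmax"(%0) {axis = 0 : i64} : (tensor<8xi32>) -> tensor<i32>
  %2 = arith.extsi %1 : tensor<i32> to tensor<i64>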

lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,8 @@ class VerifyTosaBackendContractPass
     target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
     // Basic scalar operations.
     target.addLegalDialect<tosa::TosaDialect>();
+    target.addDynamicallyLegalOp<tensor::CastOp>(opHasLegalTypes);
+    target.addDynamicallyLegalOp<arith::ExtSIOp>(opHasLegalTypes);
 
     RewritePatternSet patterns(context);
     if (failed(applyFullConversion(module, target, std::move(patterns)))) {
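
These two lines keep the backend contract satisfied: the arith.extsi (and any tensor.cast) introduced by the new lowering is treated as legal whenever its operand and result types already are. A hypothetical fragment the verifier now accepts:

  %indices = "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<8xf32>) -> tensor<i32>
  %result = arith.extsi %indices : tensor<i32> to tensor<i64>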
