1212#include " torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
1313
1414#include " ../PassDetail.h"
15+ #include " mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16+ #include " mlir/Dialect/Tensor/IR/Tensor.h"
1517#include " mlir/Dialect/Tosa/IR/TosaOps.h"
1618#include " mlir/Dialect/Traits.h"
1719#include " mlir/IR/Matchers.h"
@@ -404,6 +406,93 @@ class ConvertAtenAllDimsReductionOp
404406 }
405407};
406408
409+ template <>
410+ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
411+ AtenArgmaxOp op, OpAdaptor adaptor,
412+ ConversionPatternRewriter &rewriter) const {
413+
414+ Value self = adaptor.self ();
415+ auto selfTy = self.getType ().template cast <RankedTensorType>();
416+
417+ if (!selfTy)
418+ return op.emitError (" Only ranked tensor types supported in TOSA argmax" );
419+
420+ int64_t reduceDim;
421+ if (!matchPattern (op.dim (), m_TorchConstantInt (&reduceDim))) {
422+ // NoneType indicates reduce on all dims
423+ reduceDim = -1 ;
424+ }
425+
426+ bool keepDim = false ;
427+ if (!matchPattern (op.keepdim (), m_TorchConstantBool (&keepDim)))
428+ return rewriter.notifyMatchFailure (
429+ op, " non-const keepdim parameter unsupported" );
430+
431+ auto resultTy = getTypeConverter ()
432+ ->convertType (op.getResult ().getType ())
433+ .cast <RankedTensorType>();
434+ auto outputETy = resultTy.getElementType ();
435+
436+ // Create a single instance of tosa.argmax.
437+ // Multiple dims require chained construct.
438+ auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
439+ auto inputTy = input.getType ().cast <RankedTensorType>();
440+ auto inputShape = inputTy.getShape ();
441+ SmallVector<int64_t > outputShapeArr = {};
442+ int32_t i = 0 ;
443+
444+ for (auto &dim : inputShape) {
445+ if (i++ != reduceDim) {
446+ outputShapeArr.push_back (dim);
447+ } else {
448+ if (keepDim)
449+ outputShapeArr.push_back (1 );
450+ }
451+ }
452+
453+ // Tosa argmax output is i32, while Torch backend mandates i64.
454+ auto outputReduceTy = RankedTensorType::get (
455+ ArrayRef<int64_t >(outputShapeArr), rewriter.getI32Type ());
456+ auto reduceDimAttr =
457+ rewriter.getIntegerAttr (rewriter.getI64Type (), reduceDim);
458+ return rewriter
459+ .create <tosa::ArgMaxOp>(op->getLoc (),
460+ getTypeConverter ()->convertType (outputReduceTy),
461+ input, reduceDimAttr)
462+ .getResult ();
463+ };
464+
465+ // Convert the final index to i64 for backend finalization, However, i64
466+ // is not a defined type for tosa.cast, so using arith.extsi instead.
467+ auto castToInt64 = [&](Value result) -> LogicalResult {
468+ auto resTy = result.getType ().cast <ShapedType>();
469+ if (!resTy)
470+ return op.emitError (" Argmax: Result is not a shaped type" );
471+
472+ auto resShape = resTy.getShape ();
473+ auto outTy =
474+ RankedTensorType::get (resShape, outputETy); // rewriter.getI64Type());
475+
476+ rewriter.replaceOpWithNewOp <arith::ExtSIOp>(
477+ op, getTypeConverter ()->convertType (outTy), result);
478+
479+ return success ();
480+ };
481+
482+ if (reduceDim == -1 ) { // reducing on all dims
483+ Value input = self;
484+ for (int dim = 0 ; dim < selfTy.getRank (); dim++) {
485+ // progressively reduce each 0-th dim
486+ input = buildArgmax (0 , input);
487+ }
488+ return castToInt64 (input);
489+ } else {
490+ return castToInt64 (buildArgmax (reduceDim, self));
491+ }
492+
493+ return success ();
494+ }
495+
407496} // namespace
408497
409498// -----------------------------------------------------------------------------
@@ -415,13 +504,16 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
415504public:
416505 void getDependentDialects (DialectRegistry ®istry) const override {
417506 registry.insert <tosa::TosaDialect>();
507+ registry.insert <tensor::TensorDialect>();
508+ registry.insert <arith::ArithmeticDialect>();
418509 TorchConversion::getBackendTypeConversionDependentDialects (registry);
419510 }
420511
421512 void runOnOperation () override {
422513 MLIRContext *context = &getContext ();
423514 ConversionTarget target (*context);
424- target.addLegalDialect <tosa::TosaDialect>();
515+ target.addLegalDialect <tosa::TosaDialect, tensor::TensorDialect,
516+ arith::ArithmeticDialect>();
425517
426518 TypeConverter typeConverter;
427519 typeConverter.addConversion ([](Type type) { return type; });
@@ -491,6 +583,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
491583 INSERT_ATENOP_PATTERN (AtenReluOp);
492584 INSERT_ATENOP_PATTERN (AtenMulTensorOp);
493585 INSERT_ATENOP_PATTERN (AtenDivTensorOp);
586+ INSERT_ATENOP_PATTERN (AtenArgmaxOp);
494587#undef INSERT_ATENOP_PATTERN
495588
496589 if (failed (applyPartialConversion (getOperation (), target,
0 commit comments