|
11 | 11 |
|
12 | 12 | #include "mlir/Dialect/Utils/StaticValueUtils.h"
|
13 | 13 | #include "mlir/IR/BuiltinDialect.h"
|
| 14 | +#include "mlir/IR/BuiltinTypes.h" |
14 | 15 | #include "mlir/Transforms/DialectConversion.h"
|
15 | 16 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
16 | 17 | #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
@@ -2962,6 +2963,56 @@ class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
|
2962 | 2963 | };
|
2963 | 2964 | } // namespace
|
2964 | 2965 |
|
| 2966 | +// Decompose AtenLogCumsumExpOp into: AtenExpOp, |
| 2967 | +// AtenCumsumOp and AtenLogOp |
| 2968 | +// logcumsumexp(x)[i][j] = log(sum_{k=0}^{j} exp(x[i][k])) |
| 2969 | +namespace { |
| 2970 | +class DecomposeAtenLogCumsumExpOp |
| 2971 | + : public OpRewritePattern<AtenLogcumsumexpOp> { |
| 2972 | +public: |
| 2973 | + using OpRewritePattern<AtenLogcumsumexpOp>::OpRewritePattern; |
| 2974 | + LogicalResult matchAndRewrite(AtenLogcumsumexpOp op, |
| 2975 | + PatternRewriter &rewriter) const override { |
| 2976 | + Location loc = op.getLoc(); |
| 2977 | + Value input = op.getSelf(); |
| 2978 | + |
| 2979 | + auto inputType = dyn_cast<BaseTensorType>(input.getType()); |
| 2980 | + auto resultType = dyn_cast<BaseTensorType>(op.getType()); |
| 2981 | + |
| 2982 | + if (!inputType || !inputType.hasDtype()) |
| 2983 | + return rewriter.notifyMatchFailure(op, "input should have dtype."); |
| 2984 | + |
| 2985 | + if (isa<mlir::IntegerType>(inputType.getDtype())) |
| 2986 | + return rewriter.notifyMatchFailure(op, "integer dtype is not allowed."); |
| 2987 | + |
| 2988 | + // TODO: support complex type in future. |
| 2989 | + if (isa<mlir::ComplexType>(inputType.getDtype())) |
| 2990 | + return rewriter.notifyMatchFailure(op, |
| 2991 | + "doesn't support complex type now"); |
| 2992 | + |
| 2993 | + if (!inputType.hasSizes()) |
| 2994 | + return rewriter.notifyMatchFailure(op, "input should have known size."); |
| 2995 | + |
| 2996 | + int64_t inputRank = inputType.getSizes().size(); |
| 2997 | + int64_t dim; |
| 2998 | + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) |
| 2999 | + return rewriter.notifyMatchFailure( |
| 3000 | + op, "Unimplemented: Only constant dim value is supported."); |
| 3001 | + dim = toPositiveDim(dim, inputRank); |
| 3002 | + if (!isValidDim(dim, inputRank)) |
| 3003 | + return rewriter.notifyMatchFailure(op, "invalid dim."); |
| 3004 | + |
| 3005 | + Value dtypeVal = |
| 3006 | + getDtypeIntValueForType(rewriter, loc, inputType.getDtype()); |
| 3007 | + Value expInput = rewriter.create<AtenExpOp>(loc, resultType, input); |
| 3008 | + Value cumsum = rewriter.create<AtenCumsumOp>(loc, resultType, expInput, |
| 3009 | + op.getDim(), dtypeVal); |
| 3010 | + rewriter.replaceOpWithNewOp<AtenLogOp>(op, resultType, cumsum); |
| 3011 | + return success(); |
| 3012 | + } |
| 3013 | +}; |
| 3014 | +} // namespace |
| 3015 | + |
2965 | 3016 | namespace {
|
2966 | 3017 | class DecomposeAtenLogSigmoidOp : public OpRewritePattern<AtenLogSigmoidOp> {
|
2967 | 3018 | public:
|
@@ -12114,6 +12165,7 @@ class DecomposeComplexOpsPass
|
12114 | 12165 | addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
|
12115 | 12166 | addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
|
12116 | 12167 | addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
|
| 12168 | + addPatternIfTargetOpIsIllegal<DecomposeAtenLogCumsumExpOp>(patterns); |
12117 | 12169 | addPatternIfTargetOpIsIllegal<DecomposeAtenLogAddExpOp>(patterns);
|
12118 | 12170 | addPatternIfTargetOpIsIllegal<DecomposeAtenLogAddExp2Op>(patterns);
|
12119 | 12171 | addPatternIfTargetOpIsIllegal<DecomposeAtenHardshrinkOp>(patterns);
|
|
0 commit comments