1+ // ===-- DecomposeAggregatedOps.cpp - Decompose Aggregated Ops ---*- C++ -*-===//
2+ //
3+ // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+ // See https://llvm.org/LICENSE.txt for license information.
5+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+ //
7+ // ===----------------------------------------------------------------------===//
8+
9+ #include " gc/Transforms/Passes.h"
10+ #include " mlir/Dialect/Func/IR/FuncOps.h"
11+ #include " mlir/Dialect/Linalg/IR/Linalg.h"
12+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
14+ using namespace mlir ;
15+ namespace mlir {
16+ namespace gc {
17+ #define GEN_PASS_DEF_DECOMPOSEAGGREGATEDOPS
18+ #include " gc/Transforms/Passes.h.inc"
19+ } // namespace gc
20+ } // namespace mlir
21+
22+ namespace {
23+
24+ struct DecomposeAggregateOpsImpl : public OpRewritePattern <linalg::SoftmaxOp> {
25+ using OpRewritePattern<linalg::SoftmaxOp>::OpRewritePattern;
26+
27+ LogicalResult matchAndRewrite (linalg::SoftmaxOp softmaxOp,
28+ PatternRewriter &rewriter) const override {
29+ auto decomposableOp =
30+ cast<linalg::AggregatedOpInterface>(softmaxOp.getOperation ());
31+ FailureOr<SmallVector<Value>> maybeNewResult =
32+ decomposableOp.decomposeOperation (rewriter);
33+ if (failed (maybeNewResult))
34+ return failure ();
35+ rewriter.replaceOp (softmaxOp, *maybeNewResult);
36+ return success ();
37+ }
38+ };
39+
40+ struct DecomposeAggregatedOps
41+ : public gc::impl::DecomposeAggregatedOpsBase<DecomposeAggregatedOps> {
42+ void runOnOperation () override {
43+ RewritePatternSet patterns (getOperation ().getContext ());
44+ patterns.add <DecomposeAggregateOpsImpl>(patterns.getContext ());
45+ (void )applyPatternsAndFoldGreedily (getOperation (), std::move (patterns));
46+ }
47+ };
48+
49+ } // namespace
0 commit comments