diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..308e39a9a51e1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -601,7 +601,20 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       [{
         buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
                           attributes, ElementwiseOp::getRegionBuilder());
-      }]>
+      }]>,
+
+    // Builder that also sets the `kind` and `indexing_maps` attributes before
+    // delegating to `buildStructuredOp`.
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+      "ElementwiseKindAttr":$kind,
+      "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("kind", kind);
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, ElementwiseOp::getRegionBuilder());
+      }]>
     ];
 
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0..373842c9b03de 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
+  let summary = "Fold transpose, broadcast and other ops into elementwise";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8fdcdeff250bb..c302f6d682d69 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1710,6 +1710,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 void populateLinalgGenericOpsSpecializationPatterns(
     RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that fold operations like
+/// `linalg.transpose` into elementwise op map.
+void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
+
 /// Linalg decompose convolutions patterns
 
 /// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d18b6f8afc43b..881d9fcb4f52e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   EliminateEmptyTensors.cpp
   EraseUnusedOperandsAndResults.cpp
   FoldAddIntoDest.cpp
+  FoldIntoElementwise.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
new file mode 100644
index 0000000000000..bdd4f6025b051
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
@@ -0,0 +1,99 @@
+//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements folding ops such as transpose and broadcast into the
+// affine maps of the elementwise op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-fold-into-elementwise"
+
+namespace {
+/// Folds a `linalg.transpose` producer into the indexing maps of a
+/// `linalg.elementwise` consumer.
+///
+/// Each input that is read through an identity map and is defined by a
+/// `linalg.transpose` is replaced by the transpose's own input, and its map by
+/// the indexing map the transpose uses for that input. Every other input, and
+/// the output map, is carried over unchanged. The pattern does not match when
+/// no input was rewritten.
+struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
+  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElementwiseOp op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    SmallVector<Value> newIns;
+    SmallVector<AffineMap> newMaps;
+    for (OpOperand *operand : op.getDpsInputOperands()) {
+      AffineMap map = op.getMatchingIndexingMap(operand);
+      auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
+
+      if (!map.isIdentity() || !transposeOp) {
+        // Not foldable: keep the original operand and its map.
+        newIns.push_back(operand->get());
+        newMaps.push_back(map);
+        continue;
+      }
+      newIns.push_back(transposeOp.getInput());
+      // Push in transposeOp's inverse permutation map.
+      newMaps.push_back(transposeOp.getMatchingIndexingMap(
+          transposeOp.getDpsInputOperand(0)));
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    // The output map is carried over as-is.
+    newMaps.push_back(op.getIndexingMapsArray().back());
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(
+        op, newIns, op.getDpsInits()[0], op.getKindAttr(),
+        rewriter.getAffineMapArrayAttr(newMaps));
+    return success();
+  }
+};
+
+/// Pass that greedily applies the fold-into-elementwise patterns to the
+/// operation it runs on.
+struct LinalgFoldIntoElementwisePass
+    : public impl::LinalgFoldIntoElementwisePassBase<
+          LinalgFoldIntoElementwisePass> {
+  using impl::LinalgFoldIntoElementwisePassBase<
+      LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateLinalgFoldIntoElementwisePatterns(patterns);
+
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldTransposePattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
new file mode 100644
index 0000000000000..e83c32fb6a2cf
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -linalg-fold-into-elementwise -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+//
+// CHECK:  func.func @unary_transpose(%[[A:.+]]: tensor<16x8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:    indexing_maps = [#[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME:    ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %empty = tensor.empty() : tensor<8x16x32xf32>
+  %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+                          ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %result : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK:  func.func @binary_transposed(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:    indexing_maps = [#[[IDENTITY]], #[[TRANSPOSED]], #[[IDENTITY]]]
+// CHECK-SAME:    ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+  %empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
+  %transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
+                          outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}