Skip to content

Commit 318eeca

Browse files
committed
*Added logic to update the indexing_map attribute for collapsed MatmulOp.
*Updated test names and comments for consistency.
1 parent 1f44f8f commit 318eeca

File tree

4 files changed

+28
-19
lines changed

4 files changed

+28
-19
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -828,8 +828,8 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
828828
them to the same data type as the accumulator/output.
829829

830830
Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
831-
'indexing_maps' as shown below.This is a list attribute, so the list must include all
832-
the maps if specified.
831+
'indexing_maps' as shown below. This is a list attribute, so must include maps for all
832+
arguments if specified.
833833

834834
Example Transpose:
835835
```
@@ -845,15 +845,15 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
845845
Example Broadcast:
846846
```
847847
linalg.batch_matmul indexing_maps = [
848-
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
848+
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
849849
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
850850
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
851851
]
852852
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
853853
outs(%arg2: memref<2x3x7xf32>)
854854
```
855855

856-
Example Broadcast and transpose:
856+
Example Broadcast and Transpose:
857857
```
858858
linalg.batch_matmul indexing_maps = [
859859
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
@@ -919,7 +919,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
919919
return regionBuilder;
920920
}
921921

922-
/// Returns a list of AffineMap with the typical batch_matmul indexing charactristic.
922+
/// Returns a list with default AffineMap(s), i.e. without broadcasts and transpositions.
923923
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
924924

925925
/// Returns true if the given broadcast map \p bcastMap is valid for this op.

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/ADT/SetVector.h"
3333
#include "llvm/Support/CommandLine.h"
3434
#include "llvm/Support/Debug.h"
35+
#include <type_traits>
3536

3637
namespace mlir {
3738
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
@@ -908,11 +909,11 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
908909
PatternRewriter &rewriter) const override {
909910
// Check to not let go the batch_matmul with extended semantic, through this
910911
// transform.
911-
if (std::is_same<FromOpTy, BatchMatmulOp>::value) {
912+
if (std::is_same<FromOpTy, BatchMatmulOp>::value ||
913+
std::is_same<FromOpTy, MatmulOp>::value) {
912914
if (contractionOp.hasUserDefinedMaps()) {
913915
return rewriter.notifyMatchFailure(
914-
contractionOp,
915-
"only batch_matmul ops with non-extended semantics are supported");
916+
contractionOp, "ops with user-defined maps are not supported");
916917
}
917918
}
918919

@@ -944,10 +945,21 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
944945
loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
945946
ValueRange{collapsedInit});
946947
for (auto attr : contractionOp->getAttrs()) {
947-
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
948-
attr.getName() == "indexing_maps")
948+
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
949949
continue;
950-
collapsedOp->setAttr(attr.getName(), attr.getValue());
950+
951+
// Update the indexing_maps attribute for the collapsed MatmulOp.
952+
if (attr.getName() == "indexing_maps" &&
953+
std::is_same<FromOpTy, BatchMatmulOp>::value &&
954+
std::is_same<ToOpTy, MatmulOp>::value) {
955+
SmallVector<Attribute, 3> indexingMapsAttr = llvm::map_to_vector(
956+
MatmulOp::getDefaultIndexingMaps(rewriter.getContext()),
957+
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
958+
collapsedOp->setAttr(attr.getName(),
959+
rewriter.getArrayAttr(indexingMapsAttr));
960+
} else {
961+
collapsedOp->setAttr(attr.getName(), attr.getValue());
962+
}
951963
}
952964

953965
auto results = contractionOp.getResults();

mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,9 @@ FailureOr<Operation *>
8888
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
8989
linalg::BatchMatmulOp batchMatmulOp,
9090
bool transposeLHS) {
91-
// Check to not let go the batch_matmul with extended semantic, through this
92-
// transform.
9391
if (batchMatmulOp.hasUserDefinedMaps()) {
9492
return rewriter.notifyMatchFailure(
95-
batchMatmulOp,
96-
"only batch_matmul ops with non-extended semantics are supported");
93+
batchMatmulOp, "ops with user-defined maps are not supported");
9794
}
9895

9996
if (!bufferization::hasTensorSemantics(batchMatmulOp))

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,11 +1578,11 @@ func.func @batch_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memre
15781578

15791579
// -----
15801580

1581-
// CHECK-LABEL: func @batch_matmul_explicit_transpose_a
1581+
// CHECK-LABEL: func @batch_matmul_explicit_transpose_A
15821582
// CHECK: linalg.batch_matmul
15831583
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
15841584
// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
1585-
func.func @batch_matmul_explicit_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
1585+
func.func @batch_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
15861586
linalg.batch_matmul indexing_maps = [
15871587
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
15881588
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
@@ -1594,11 +1594,11 @@ func.func @batch_matmul_explicit_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: me
15941594

15951595
// -----
15961596

1597-
// CHECK-LABEL: func @batch_matmul_explicit_transpose_b
1597+
// CHECK-LABEL: func @batch_matmul_explicit_transpose_B
15981598
// CHECK: linalg.batch_matmul
15991599
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
16001600
// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
1601-
func.func @batch_matmul_explicit_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
1601+
func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
16021602
linalg.batch_matmul indexing_maps = [
16031603
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
16041604
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,

0 commit comments

Comments
 (0)