Skip to content

Commit 4968fa7

Browse files
committed
[MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batch_matmul' operation.
Goals: 1. To add syntax and semantic to 'batch_matmul' without changing any of the existing syntax expectations for current usage. batch_matmul is still just batch_matmul. 2. Move the definition of batch_matmul from linalg OpDsl to tablegen ODS infra. Scope of this patch: To expose broadcast and transpose semantics on the 'batch_matmul'. The broadcast and transpose semantic is as follows: By default 'linalg.batch_matmul' behavior will remain as is. Broadcast and Transpose semantics can be appiled by specifying the explicit attribute 'indexing_maps' as shown below.This is a list attribute, so the list must include all the maps if specified. Example Transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, //transpose affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>) ``` Example Broadcast: ``` linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d3)>, //broadcast affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] ins(%arg0, %arg1 : memref<5xf32>,memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>) ``` Example Broadcast and transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, //broadcast affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, //transpose affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ] ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>) ```
1 parent b275309 commit 4968fa7

File tree

8 files changed

+634
-88
lines changed

8 files changed

+634
-88
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig
14721472
- !ScalarExpression
14731473
scalar_arg: rhs
14741474
--- !LinalgOpConfig
1475-
metadata: !LinalgOpMetadata
1476-
name: batch_matmul
1477-
cpp_class_name: BatchMatmulOp
1478-
doc: |-
1479-
Performs a batched matrix multiplication of two 3D inputs.
1480-
1481-
Numeric casting is performed on the operands to the inner multiply, promoting
1482-
them to the same data type as the accumulator/output.
1483-
implements:
1484-
- LinalgContractionOpInterface
1485-
structured_op: !LinalgStructuredOpConfig
1486-
args:
1487-
- !LinalgOperandDefConfig
1488-
name: A
1489-
kind: input_tensor
1490-
type_var: T1
1491-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
1492-
- !LinalgOperandDefConfig
1493-
name: B
1494-
kind: input_tensor
1495-
type_var: T2
1496-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
1497-
- !LinalgOperandDefConfig
1498-
name: C
1499-
kind: output_tensor
1500-
type_var: U
1501-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
1502-
indexing_maps: !LinalgIndexingMapsConfig
1503-
static_indexing_maps:
1504-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
1505-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
1506-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
1507-
iterator_types:
1508-
- parallel
1509-
- parallel
1510-
- parallel
1511-
- reduction
1512-
assignments:
1513-
- !ScalarAssign
1514-
arg: C
1515-
value: !ScalarExpression
1516-
scalar_fn:
1517-
kind: binary
1518-
fn_name: add
1519-
operands:
1520-
- !ScalarExpression
1521-
scalar_arg: C
1522-
- !ScalarExpression
1523-
scalar_fn:
1524-
kind: binary
1525-
fn_name: mul
1526-
operands:
1527-
- !ScalarExpression
1528-
scalar_fn:
1529-
kind: type
1530-
fn_name: cast_signed
1531-
type_var: U
1532-
operands:
1533-
- !ScalarExpression
1534-
scalar_arg: A
1535-
- !ScalarExpression
1536-
scalar_fn:
1537-
kind: type
1538-
fn_name: cast_signed
1539-
type_var: U
1540-
operands:
1541-
- !ScalarExpression
1542-
scalar_arg: B
1543-
--- !LinalgOpConfig
15441475
metadata: !LinalgOpMetadata
15451476
name: batch_matmul_transpose_a
15461477
cpp_class_name: BatchMatmulTransposeAOp

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,130 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
816816
}];
817817
}
818818

819+
//===----------------------------------------------------------------------===//
820+
// Op definition for BatchMatmulOp
821+
//===----------------------------------------------------------------------===//
822+
823+
def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
824+
/*extraInterfaces=*/[LinalgContractionOpInterface])> {
825+
826+
let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
827+
let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
828+
them to the same data type as the accumulator/output.
829+
830+
Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
831+
'indexing_maps' as shown below.This is a list attribute, so the list must include all
832+
the maps if specified.
833+
834+
Example Transpose:
835+
```
836+
linalg.batch_matmul indexing_maps = [
837+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
838+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
839+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
840+
]
841+
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
842+
outs(%arg2: memref<2x3x7xf32>)
843+
```
844+
845+
Example Broadcast:
846+
```
847+
linalg.batch_matmul indexing_maps = [
848+
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
849+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
850+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
851+
]
852+
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
853+
outs(%arg2: memref<2x3x7xf32>)
854+
```
855+
856+
Example Broadcast and transpose:
857+
```
858+
linalg.batch_matmul indexing_maps = [
859+
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
860+
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
861+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
862+
]
863+
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
864+
outs(%arg2: memref<2x3x7xf32>)
865+
```
866+
}];
867+
868+
let arguments = (ins
869+
Variadic<AnyType>:$inputs,
870+
Variadic<AnyShaped>:$outputs,
871+
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
872+
);
873+
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
874+
let regions = (region AnyRegion:$region);
875+
876+
let skipDefaultBuilders = 1;
877+
let builders = [
878+
OpBuilder<
879+
(ins "ValueRange":$inputs, "ValueRange":$outputs,
880+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
881+
[{
882+
buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
883+
attributes, BatchMatmulOp::getRegionBuilder(),
884+
BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
885+
}]>,
886+
OpBuilder<
887+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
888+
"ValueRange":$outputs,
889+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
890+
[{
891+
buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
892+
inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
893+
BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
894+
}]>,
895+
OpBuilder<
896+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
897+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
898+
[{
899+
$_state.addOperands(operands);
900+
$_state.addAttributes(attributes);
901+
$_state.addTypes(resultTensorTypes);
902+
(void)$_state.addRegion(),
903+
BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
904+
}]>
905+
906+
];
907+
let hasCustomAssemblyFormat = 1;
908+
let hasFolder = 1;
909+
let hasVerifier = 1;
910+
911+
let extraClassDeclaration = structuredOpsBaseDecls # [{
912+
913+
SmallVector<utils::IteratorType> getIteratorTypesArray();
914+
static void regionBuilder(ImplicitLocOpBuilder &b,
915+
Block &block, ArrayRef<NamedAttribute> attrs);
916+
static std::function<void(ImplicitLocOpBuilder &,
917+
Block &, ArrayRef<NamedAttribute>)>
918+
getRegionBuilder() {
919+
return regionBuilder;
920+
}
921+
922+
/// Returns a list of AffineMap with the typical batch_matmul indexing charactristic.
923+
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
924+
925+
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
926+
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
927+
928+
::mlir::MutableOperandRange getDpsInitsMutable() {
929+
return getOutputsMutable();
930+
}
931+
932+
// Generic methods.
933+
static unsigned getNumRegionArgs();
934+
bool hasDynamicIndexingMaps() { return true; }
935+
std::string getLibraryCallName();
936+
/// Check if the op has broadcast and/or transpose semantic. Returns true if the
937+
/// user defined indexing maps are not equal to default map.
938+
bool hasUserDefinedMaps();
939+
}];
940+
}
941+
942+
819943
//===----------------------------------------------------------------------===//
820944
// Named Linalg ops, implemented as a declarative configurations of generic ops.
821945
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)