Skip to content

Commit 0d667f5

Browse files
sakupan102ita9naiwa
andcommitted
[MLIR] Extend linalg.pack and linalg.unpack to accept memref
Extend linalg.pack and linalg.unpack to accept memref operands in addition to tensors. As part of this change, we now disable all transformations when these ops have memref semantics. Closes #129004 Co-authored-by: Hyunsung Lee <[email protected]> Signed-off-by: Ryutaro Okada <[email protected]>
1 parent 02c68b3 commit 0d667f5

File tree

11 files changed

+771
-135
lines changed

11 files changed

+771
-135
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 47 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This file defines Pack + Unpack Ops that have been moved from the Tensor
10-
// dialect. As such, these are defined as memory-effect-free and only accept
11-
// "tensors" as inputs.
12-
//
13-
// TODO: Once a good motivating example is identified, relax these
14-
// restrictions.
10+
// dialect.
1511
//
1612
//===----------------------------------------------------------------------===//
1713

@@ -30,24 +26,27 @@ include "mlir/IR/OpAsmInterface.td"
3026
// RelayoutOp
3127
//===----------------------------------------------------------------------===//
3228

33-
class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
34-
Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
35-
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
36-
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
37-
ConditionallySpeculatable, NoMemoryEffect,
38-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
29+
class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
30+
: Op<Linalg_Dialect, mnemonic,
31+
!listconcat(
32+
traits, [DeclareOpInterfaceMethods<
33+
OpAsmOpInterface, ["getAsmResultNames"]>,
34+
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
35+
ConditionallySpeculatable,
36+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
37+
DeclareOpInterfaceMethods<
38+
ReifyRankedShapedTypeOpInterface, [
3939
"reifyResultShapes"]>,
40-
TypesMatchWith<"result type matches type of dest",
41-
"dest", "result",
42-
"$_self">])> {
40+
OptionalTypesMatchWith<"result type matches type of dest",
41+
"dest", "result", "$_self">])> {
4342

4443
code commonExtraClassDeclaration = [{
4544
size_t getSourceRank() { return getSourceType().getRank(); };
4645
size_t getDestRank() { return getDestType().getRank(); };
47-
RankedTensorType getSourceType() {
48-
return ::llvm::cast<RankedTensorType>(getSource().getType()); };
49-
RankedTensorType getDestType() {
50-
return ::llvm::cast<RankedTensorType>(getDest().getType()); };
46+
ShapedType getSourceType() {
47+
return ::llvm::cast<ShapedType>(getSource().getType()); };
48+
ShapedType getDestType() {
49+
return ::llvm::cast<ShapedType>(getDest().getType()); };
5150

5251
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
5352

@@ -192,23 +191,12 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
192191
// expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
193192
```
194193
}];
195-
let arguments = (ins AnyRankedTensor:$source,
196-
AnyRankedTensor:$dest,
197-
Optional<AnyType>:$padding_value,
198-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
199-
DenseI64ArrayAttr:$inner_dims_pos,
200-
Variadic<Index>:$inner_tiles,
201-
DenseI64ArrayAttr:$static_inner_tiles);
202-
let results = (outs AnyRankedTensor:$result);
203-
let assemblyFormat = [{
204-
$source
205-
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
206-
(`outer_dims_perm` `=` $outer_dims_perm^)?
207-
`inner_dims_pos` `=` $inner_dims_pos
208-
`inner_tiles` `=`
209-
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
210-
`into` $dest attr-dict `:` type($source) `->` type($dest)
211-
}];
194+
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
195+
TensorOrMemRef<[AnyType]>:$dest, Optional<AnyType>:$padding_value,
196+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
197+
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
198+
DenseI64ArrayAttr:$static_inner_tiles);
199+
let results = (outs Optional<AnyRankedTensor>:$result);
212200

213201
let builders = [
214202
OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -218,7 +206,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
218206
CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
219207
];
220208

221-
let extraClassDeclaration = commonExtraClassDeclaration # [{
209+
let extraClassDeclaration = commonExtraClassDeclaration#[{
222210
// Method to get the shape of the result as `SmallVector<OpFoldResult>`.
223211
// This is a static method to allow getting the shape of the destination
224212
// expected while creating a `pack` op.
@@ -230,7 +218,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
230218
// Method to get the `RankedTensorType` of the result based on the inner
231219
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
232220
// of outer loops (outerDimsPerm).
233-
static RankedTensorType inferPackedType(RankedTensorType sourceType,
221+
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
222+
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
223+
ArrayRef<int64_t> outerDimsPerm = {});
224+
225+
// Method to get the `MemRefType` of the result based on the inner
226+
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
227+
// of outer loops (outerDimsPerm).
228+
static MemRefType inferPackedMemRefType(MemRefType sourceType,
229+
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
230+
ArrayRef<int64_t> outerDimsPerm = {});
231+
232+
// Returns the shape of the packed type. It is a shared helper helps type inference methods in a way that ensures that they agree on which dimensions are dynamic.
233+
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
234234
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
235235
ArrayRef<int64_t> outerDimsPerm = {});
236236

@@ -282,6 +282,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
282282
let hasCanonicalizeMethod = 1;
283283

284284
let hasFolder = 1;
285+
286+
let hasCustomAssemblyFormat = 1;
285287
}
286288

287289
//===----------------------------------------------------------------------===//
@@ -349,21 +351,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
349351
// Outer Dims: 9x3x8 Inner Dims: 4x2
350352
```
351353
}];
352-
let arguments = (ins AnyRankedTensor:$source,
353-
AnyRankedTensor:$dest,
354-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
355-
DenseI64ArrayAttr:$inner_dims_pos,
356-
Variadic<Index>:$inner_tiles,
357-
DenseI64ArrayAttr:$static_inner_tiles);
358-
let results = (outs AnyRankedTensor:$result);
359-
let assemblyFormat = [{
360-
$source
361-
(`outer_dims_perm` `=` $outer_dims_perm^)?
362-
`inner_dims_pos` `=` $inner_dims_pos
363-
`inner_tiles` `=`
364-
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
365-
`into` $dest attr-dict `:` type($source) `->` type($dest)
366-
}];
354+
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
355+
TensorOrMemRef<[AnyType]>:$dest,
356+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
357+
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
358+
DenseI64ArrayAttr:$static_inner_tiles);
359+
let results = (outs Optional<AnyRankedTensor>:$result);
367360

368361
let builders = [
369362
OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -406,6 +399,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
406399
let hasCanonicalizeMethod = 1;
407400

408401
let hasFolder = 1;
402+
403+
let hasCustomAssemblyFormat = 1;
409404
}
410405

411406
#endif // LINALG_RELEAYOUT_OPS

0 commit comments

Comments
 (0)