Skip to content

Commit 954c49c

Browse files
sakupan102ita9naiwa
andcommitted
[MLIR] Extend linalg.pack and linalg.unpack to accept memref
Extend linalg.pack and linalg.unpack to accept memref operands in addition to tensors. As part of this change, we now disable all transformations when these ops have memref semantics. Closes #129004 Co-authored-by: Hyunsung Lee <[email protected]> Signed-off-by: Ryutaro Okada <[email protected]>
1 parent 02c68b3 commit 954c49c

File tree

11 files changed

+770
-130
lines changed

11 files changed

+770
-130
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 46 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,27 @@ include "mlir/IR/OpAsmInterface.td"
3030
// RelayoutOp
3131
//===----------------------------------------------------------------------===//
3232

33-
class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
34-
Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
35-
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
36-
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
37-
ConditionallySpeculatable, NoMemoryEffect,
38-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
33+
class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
34+
: Op<Linalg_Dialect, mnemonic,
35+
!listconcat(
36+
traits, [DeclareOpInterfaceMethods<
37+
OpAsmOpInterface, ["getAsmResultNames"]>,
38+
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
39+
ConditionallySpeculatable,
40+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
41+
DeclareOpInterfaceMethods<
42+
ReifyRankedShapedTypeOpInterface, [
3943
"reifyResultShapes"]>,
40-
TypesMatchWith<"result type matches type of dest",
41-
"dest", "result",
42-
"$_self">])> {
44+
OptionalTypesMatchWith<"result type matches type of dest",
45+
"dest", "result", "$_self">])> {
4346

4447
code commonExtraClassDeclaration = [{
4548
size_t getSourceRank() { return getSourceType().getRank(); };
4649
size_t getDestRank() { return getDestType().getRank(); };
47-
RankedTensorType getSourceType() {
48-
return ::llvm::cast<RankedTensorType>(getSource().getType()); };
49-
RankedTensorType getDestType() {
50-
return ::llvm::cast<RankedTensorType>(getDest().getType()); };
50+
ShapedType getSourceType() {
51+
return ::llvm::cast<ShapedType>(getSource().getType()); };
52+
ShapedType getDestType() {
53+
return ::llvm::cast<ShapedType>(getDest().getType()); };
5154

5255
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
5356

@@ -192,23 +195,12 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
192195
// expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
193196
```
194197
}];
195-
let arguments = (ins AnyRankedTensor:$source,
196-
AnyRankedTensor:$dest,
197-
Optional<AnyType>:$padding_value,
198-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
199-
DenseI64ArrayAttr:$inner_dims_pos,
200-
Variadic<Index>:$inner_tiles,
201-
DenseI64ArrayAttr:$static_inner_tiles);
202-
let results = (outs AnyRankedTensor:$result);
203-
let assemblyFormat = [{
204-
$source
205-
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
206-
(`outer_dims_perm` `=` $outer_dims_perm^)?
207-
`inner_dims_pos` `=` $inner_dims_pos
208-
`inner_tiles` `=`
209-
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
210-
`into` $dest attr-dict `:` type($source) `->` type($dest)
211-
}];
198+
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
199+
TensorOrMemRef<[AnyType]>:$dest, Optional<AnyType>:$padding_value,
200+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
201+
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
202+
DenseI64ArrayAttr:$static_inner_tiles);
203+
let results = (outs Optional<AnyRankedTensor>:$result);
212204

213205
let builders = [
214206
OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -218,7 +210,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
218210
CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
219211
];
220212

221-
let extraClassDeclaration = commonExtraClassDeclaration # [{
213+
let extraClassDeclaration = commonExtraClassDeclaration#[{
222214
// Method to get the shape of the result as `SmallVector<OpFoldResult>`.
223215
// This is a static method to allow getting the shape of the destination
224216
// expected while creating a `pack` op.
@@ -230,7 +222,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
230222
// Method to get the `RankedTensorType` of the result based on the inner
231223
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
232224
// of outer loops (outerDimsPerm).
233-
static RankedTensorType inferPackedType(RankedTensorType sourceType,
225+
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
226+
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
227+
ArrayRef<int64_t> outerDimsPerm = {});
228+
229+
// Method to get the `MemRefType` of the result based on the inner
230+
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
231+
// of outer loops (outerDimsPerm).
232+
static MemRefType inferPackedMemRefType(MemRefType sourceType,
233+
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
234+
ArrayRef<int64_t> outerDimsPerm = {});
235+
236+
// Returns the shape of the packed type. It is a shared helper helps type inference methods in a way that ensures that they agree on which dimensions are dynamic.
237+
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
234238
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
235239
ArrayRef<int64_t> outerDimsPerm = {});
236240

@@ -282,6 +286,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
282286
let hasCanonicalizeMethod = 1;
283287

284288
let hasFolder = 1;
289+
290+
let hasCustomAssemblyFormat = 1;
285291
}
286292

287293
//===----------------------------------------------------------------------===//
@@ -349,21 +355,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
349355
// Outer Dims: 9x3x8 Inner Dims: 4x2
350356
```
351357
}];
352-
let arguments = (ins AnyRankedTensor:$source,
353-
AnyRankedTensor:$dest,
354-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
355-
DenseI64ArrayAttr:$inner_dims_pos,
356-
Variadic<Index>:$inner_tiles,
357-
DenseI64ArrayAttr:$static_inner_tiles);
358-
let results = (outs AnyRankedTensor:$result);
359-
let assemblyFormat = [{
360-
$source
361-
(`outer_dims_perm` `=` $outer_dims_perm^)?
362-
`inner_dims_pos` `=` $inner_dims_pos
363-
`inner_tiles` `=`
364-
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
365-
`into` $dest attr-dict `:` type($source) `->` type($dest)
366-
}];
358+
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
359+
TensorOrMemRef<[AnyType]>:$dest,
360+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
361+
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
362+
DenseI64ArrayAttr:$static_inner_tiles);
363+
let results = (outs Optional<AnyRankedTensor>:$result);
367364

368365
let builders = [
369366
OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -406,6 +403,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
406403
let hasCanonicalizeMethod = 1;
407404

408405
let hasFolder = 1;
406+
407+
let hasCustomAssemblyFormat = 1;
409408
}
410409

411410
#endif // LINALG_RELEAYOUT_OPS

0 commit comments

Comments
 (0)