Skip to content

Commit a755aec

Browse files
sakupan102ita9naiwa
andcommitted
[MLIR] Extend linalg.pack and linalg.unpack to accept memref
Extend linalg.pack and linalg.unpack to accept memref operands in addition to tensors. As part of this change, we now disable all transformations when these ops have memref semantics. Closes #129004 Co-authored-by: Hyunsung Lee <[email protected]> Signed-off-by: Ryutaro Okada <[email protected]>
1 parent 7734276 commit a755aec

File tree

11 files changed

+766
-130
lines changed

11 files changed

+766
-130
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 46 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,26 @@ include "mlir/IR/OpAsmInterface.td"
3030
// RelayoutOp
3131
//===----------------------------------------------------------------------===//
3232

33-
class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
34-
Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
35-
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
36-
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
37-
ConditionallySpeculatable, NoMemoryEffect,
38-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
39-
TypesMatchWith<"result type matches type of dest",
40-
"dest", "result",
41-
"$_self">])> {
33+
class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
34+
: Op<Linalg_Dialect, mnemonic,
35+
!listconcat(
36+
traits, [DeclareOpInterfaceMethods<
37+
OpAsmOpInterface, ["getAsmResultNames"]>,
38+
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
39+
ConditionallySpeculatable,
40+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
41+
DeclareOpInterfaceMethods<
42+
ReifyRankedShapedTypeOpInterface>,
43+
OptionalTypesMatchWith<"result type matches type of dest",
44+
"dest", "result", "$_self">])> {
4245

4346
code commonExtraClassDeclaration = [{
4447
size_t getSourceRank() { return getSourceType().getRank(); };
4548
size_t getDestRank() { return getDestType().getRank(); };
46-
RankedTensorType getSourceType() {
47-
return ::llvm::cast<RankedTensorType>(getSource().getType()); };
48-
RankedTensorType getDestType() {
49-
return ::llvm::cast<RankedTensorType>(getDest().getType()); };
49+
ShapedType getSourceType() {
50+
return ::llvm::cast<ShapedType>(getSource().getType()); };
51+
ShapedType getDestType() {
52+
return ::llvm::cast<ShapedType>(getDest().getType()); };
5053

5154
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
5255

@@ -191,23 +194,12 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
191194
// expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
192195
```
193196
}];
194-
let arguments = (ins AnyRankedTensor:$source,
195-
AnyRankedTensor:$dest,
196-
Optional<AnyType>:$padding_value,
197-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
198-
DenseI64ArrayAttr:$inner_dims_pos,
199-
Variadic<Index>:$inner_tiles,
200-
DenseI64ArrayAttr:$static_inner_tiles);
201-
let results = (outs AnyRankedTensor:$result);
202-
let assemblyFormat = [{
203-
$source
204-
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
205-
(`outer_dims_perm` `=` $outer_dims_perm^)?
206-
`inner_dims_pos` `=` $inner_dims_pos
207-
`inner_tiles` `=`
208-
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
209-
`into` $dest attr-dict `:` type($source) `->` type($dest)
210-
}];
197+
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
198+
TensorOrMemRef<[AnyType]>:$dest, Optional<AnyType>:$padding_value,
199+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
200+
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
201+
DenseI64ArrayAttr:$static_inner_tiles);
202+
let results = (outs Optional<AnyRankedTensor>:$result);
211203

212204
let builders = [
213205
OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -217,7 +209,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
217209
CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
218210
];
219211

220-
let extraClassDeclaration = commonExtraClassDeclaration # [{
212+
let extraClassDeclaration = commonExtraClassDeclaration#[{
221213
// Method to get the shape of the result as `SmallVector<OpFoldResult>`.
222214
// This is a static method to allow getting the shape of the destination
223215
// expected while creating a `pack` op.
@@ -229,7 +221,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
229221
// Method to get the `RankedTensorType` of the result based on the inner
230222
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
231223
// of outer loops (outerDimsPerm).
232-
static RankedTensorType inferPackedType(RankedTensorType sourceType,
224+
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
225+
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
226+
ArrayRef<int64_t> outerDimsPerm = {});
227+
228+
// Method to get the `MemRefType` of the result based on the inner
229+
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
230+
// of outer loops (outerDimsPerm).
231+
static MemRefType inferPackedMemRefType(MemRefType sourceType,
232+
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
233+
ArrayRef<int64_t> outerDimsPerm = {});
234+
235+
// Returns the shape of the packed type. It is a shared helper helps type inference methods in a way that ensures that they agree on which dimensions are dynamic.
236+
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
233237
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
234238
ArrayRef<int64_t> outerDimsPerm = {});
235239

@@ -281,6 +285,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
281285
let hasCanonicalizeMethod = 1;
282286

283287
let hasFolder = 1;
288+
289+
let hasCustomAssemblyFormat = 1;
284290
}
285291

286292
//===----------------------------------------------------------------------===//
@@ -348,21 +354,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
348354
// Outer Dims: 9x3x8 Inner Dims: 4x2
349355
```
350356
}];
351-
let arguments = (ins AnyRankedTensor:$source,
352-
AnyRankedTensor:$dest,
353-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
354-
DenseI64ArrayAttr:$inner_dims_pos,
355-
Variadic<Index>:$inner_tiles,
356-
DenseI64ArrayAttr:$static_inner_tiles);
357-
let results = (outs AnyRankedTensor:$result);
358-
let assemblyFormat = [{
359-
$source
360-
(`outer_dims_perm` `=` $outer_dims_perm^)?
361-
`inner_dims_pos` `=` $inner_dims_pos
362-
`inner_tiles` `=`
363-
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
364-
`into` $dest attr-dict `:` type($source) `->` type($dest)
365-
}];
357+
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
358+
TensorOrMemRef<[AnyType]>:$dest,
359+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
360+
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
361+
DenseI64ArrayAttr:$static_inner_tiles);
362+
let results = (outs Optional<AnyRankedTensor>:$result);
366363

367364
let builders = [
368365
OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -405,6 +402,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
405402
let hasCanonicalizeMethod = 1;
406403

407404
let hasFolder = 1;
405+
406+
let hasCustomAssemblyFormat = 1;
408407
}
409408

410409
#endif // LINALG_RELEAYOUT_OPS

0 commit comments

Comments
 (0)