@@ -30,24 +30,27 @@ include "mlir/IR/OpAsmInterface.td"
3030// RelayoutOp
3131//===----------------------------------------------------------------------===//
3232
33- class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
34- Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
35- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
36- DestinationStyleOpInterface, LinalgRelayoutOpInterface,
37- ConditionallySpeculatable, NoMemoryEffect,
38- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
33+ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
34+ : Op<Linalg_Dialect, mnemonic,
35+ !listconcat(
36+ traits, [DeclareOpInterfaceMethods<
37+ OpAsmOpInterface, ["getAsmResultNames"]>,
38+ DestinationStyleOpInterface, LinalgRelayoutOpInterface,
39+ ConditionallySpeculatable,
40+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
41+ DeclareOpInterfaceMethods<
42+ ReifyRankedShapedTypeOpInterface, [
3943 "reifyResultShapes"]>,
40- TypesMatchWith<"result type matches type of dest",
41- "dest", "result",
42- "$_self">])> {
44+ OptionalTypesMatchWith<"result type matches type of dest",
45+ "dest", "result", "$_self">])> {
4346
4447 code commonExtraClassDeclaration = [{
4548 size_t getSourceRank() { return getSourceType().getRank(); };
4649 size_t getDestRank() { return getDestType().getRank(); };
47- RankedTensorType getSourceType() {
48- return ::llvm::cast<RankedTensorType >(getSource().getType()); };
49- RankedTensorType getDestType() {
50- return ::llvm::cast<RankedTensorType >(getDest().getType()); };
50+ ShapedType getSourceType() {
51+ return ::llvm::cast<ShapedType >(getSource().getType()); };
52+ ShapedType getDestType() {
53+ return ::llvm::cast<ShapedType >(getDest().getType()); };
5154
5255 MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
5356
@@ -192,23 +195,12 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
192195 // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
193196 ```
194197 }];
195- let arguments = (ins AnyRankedTensor:$source,
196- AnyRankedTensor:$dest,
197- Optional<AnyType>:$padding_value,
198- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
199- DenseI64ArrayAttr:$inner_dims_pos,
200- Variadic<Index>:$inner_tiles,
201- DenseI64ArrayAttr:$static_inner_tiles);
202- let results = (outs AnyRankedTensor:$result);
203- let assemblyFormat = [{
204- $source
205- (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
206- (`outer_dims_perm` `=` $outer_dims_perm^)?
207- `inner_dims_pos` `=` $inner_dims_pos
208- `inner_tiles` `=`
209- custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
210- `into` $dest attr-dict `:` type($source) `->` type($dest)
211- }];
198+ let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
199+ TensorOrMemRef<[AnyType]>:$dest, Optional<AnyType>:$padding_value,
200+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
201+ DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
202+ DenseI64ArrayAttr:$static_inner_tiles);
203+ let results = (outs Optional<AnyRankedTensor>:$result);
212204
213205 let builders = [
214206 OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -218,7 +210,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
218210 CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
219211 ];
220212
221- let extraClassDeclaration = commonExtraClassDeclaration # [{
213+ let extraClassDeclaration = commonExtraClassDeclaration# [{
222214 // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
223215 // This is a static method to allow getting the shape of the destination
224216 // expected while creating a `pack` op.
@@ -230,7 +222,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
230222 // Method to get the `RankedTensorType` of the result based on the inner
231223 // tiles, position of the inner tiles (innerDimsPos) and interchange vector
232224 // of outer loops (outerDimsPerm).
233- static RankedTensorType inferPackedType(RankedTensorType sourceType,
225+ static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
226+ ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
227+ ArrayRef<int64_t> outerDimsPerm = {});
228+
229+ // Method to get the `MemRefType` of the result based on the inner
230+ // tiles, position of the inner tiles (innerDimsPos) and interchange vector
231+ // of outer loops (outerDimsPerm).
232+ static MemRefType inferPackedMemRefType(MemRefType sourceType,
233+ ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
234+ ArrayRef<int64_t> outerDimsPerm = {});
235+
236+ // Returns the shape of the packed type. It is a shared helper helps type inference methods in a way that ensures that they agree on which dimensions are dynamic.
237+ static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
234238 ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
235239 ArrayRef<int64_t> outerDimsPerm = {});
236240
@@ -282,6 +286,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
282286 let hasCanonicalizeMethod = 1;
283287
284288 let hasFolder = 1;
289+
290+ let hasCustomAssemblyFormat = 1;
285291}
286292
287293//===----------------------------------------------------------------------===//
@@ -349,21 +355,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
349355 // Outer Dims: 9x3x8 Inner Dims: 4x2
350356 ```
351357 }];
352- let arguments = (ins AnyRankedTensor:$source,
353- AnyRankedTensor:$dest,
354- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
355- DenseI64ArrayAttr:$inner_dims_pos,
356- Variadic<Index>:$inner_tiles,
357- DenseI64ArrayAttr:$static_inner_tiles);
358- let results = (outs AnyRankedTensor:$result);
359- let assemblyFormat = [{
360- $source
361- (`outer_dims_perm` `=` $outer_dims_perm^)?
362- `inner_dims_pos` `=` $inner_dims_pos
363- `inner_tiles` `=`
364- custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
365- `into` $dest attr-dict `:` type($source) `->` type($dest)
366- }];
358+ let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
359+ TensorOrMemRef<[AnyType]>:$dest,
360+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
361+ DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
362+ DenseI64ArrayAttr:$static_inner_tiles);
363+ let results = (outs Optional<AnyRankedTensor>:$result);
367364
368365 let builders = [
369366 OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -406,6 +403,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
406403 let hasCanonicalizeMethod = 1;
407404
408405 let hasFolder = 1;
406+
407+ let hasCustomAssemblyFormat = 1;
409408}
410409
411410#endif // LINALG_RELEAYOUT_OPS
0 commit comments