@@ -30,23 +30,26 @@ include "mlir/IR/OpAsmInterface.td"
3030// RelayoutOp
3131//===----------------------------------------------------------------------===//
3232
33- class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
34- Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
35- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
36- DestinationStyleOpInterface, LinalgRelayoutOpInterface,
37- ConditionallySpeculatable, NoMemoryEffect,
38- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
39- TypesMatchWith<"result type matches type of dest",
40- "dest", "result",
41- "$_self">])> {
33+ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
34+ : Op<Linalg_Dialect, mnemonic,
35+ !listconcat(
36+ traits, [DeclareOpInterfaceMethods<
37+ OpAsmOpInterface, ["getAsmResultNames"]>,
38+ DestinationStyleOpInterface, LinalgRelayoutOpInterface,
39+ ConditionallySpeculatable,
40+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
41+ DeclareOpInterfaceMethods<
42+ ReifyRankedShapedTypeOpInterface>,
43+ OptionalTypesMatchWith<"result type matches type of dest",
44+ "dest", "result", "$_self">])> {
4245
4346 code commonExtraClassDeclaration = [{
4447 size_t getSourceRank() { return getSourceType().getRank(); };
4548 size_t getDestRank() { return getDestType().getRank(); };
46- RankedTensorType getSourceType() {
47- return ::llvm::cast<RankedTensorType >(getSource().getType()); };
48- RankedTensorType getDestType() {
49- return ::llvm::cast<RankedTensorType >(getDest().getType()); };
49+ ShapedType getSourceType() {
50+ return ::llvm::cast<ShapedType >(getSource().getType()); };
51+ ShapedType getDestType() {
52+ return ::llvm::cast<ShapedType >(getDest().getType()); };
5053
5154 MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
5255
@@ -191,23 +194,12 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
191194 // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
192195 ```
193196 }];
194- let arguments = (ins AnyRankedTensor:$source,
195- AnyRankedTensor:$dest,
196- Optional<AnyType>:$padding_value,
197- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
198- DenseI64ArrayAttr:$inner_dims_pos,
199- Variadic<Index>:$inner_tiles,
200- DenseI64ArrayAttr:$static_inner_tiles);
201- let results = (outs AnyRankedTensor:$result);
202- let assemblyFormat = [{
203- $source
204- (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
205- (`outer_dims_perm` `=` $outer_dims_perm^)?
206- `inner_dims_pos` `=` $inner_dims_pos
207- `inner_tiles` `=`
208- custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
209- `into` $dest attr-dict `:` type($source) `->` type($dest)
210- }];
197+ let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
198+ TensorOrMemRef<[AnyType]>:$dest, Optional<AnyType>:$padding_value,
199+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
200+ DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
201+ DenseI64ArrayAttr:$static_inner_tiles);
202+ let results = (outs Optional<AnyRankedTensor>:$result);
211203
212204 let builders = [
213205 OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -217,7 +209,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
217209 CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
218210 ];
219211
220- let extraClassDeclaration = commonExtraClassDeclaration # [{
212+ let extraClassDeclaration = commonExtraClassDeclaration# [{
221213 // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
222214 // This is a static method to allow getting the shape of the destination
223215 // expected while creating a `pack` op.
@@ -229,7 +221,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
229221 // Method to get the `RankedTensorType` of the result based on the inner
230222 // tiles, position of the inner tiles (innerDimsPos) and interchange vector
231223 // of outer loops (outerDimsPerm).
232- static RankedTensorType inferPackedType(RankedTensorType sourceType,
224+ static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
225+ ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
226+ ArrayRef<int64_t> outerDimsPerm = {});
227+
228+ // Method to get the `MemRefType` of the result based on the inner
229+ // tiles, position of the inner tiles (innerDimsPos) and interchange vector
230+ // of outer loops (outerDimsPerm).
231+ static MemRefType inferPackedMemRefType(MemRefType sourceType,
232+ ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
233+ ArrayRef<int64_t> outerDimsPerm = {});
234+
235+ // Returns the shape of the packed type. It is a shared helper helps type inference methods in a way that ensures that they agree on which dimensions are dynamic.
236+ static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
233237 ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
234238 ArrayRef<int64_t> outerDimsPerm = {});
235239
@@ -281,6 +285,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
281285 let hasCanonicalizeMethod = 1;
282286
283287 let hasFolder = 1;
288+
289+ let hasCustomAssemblyFormat = 1;
284290}
285291
286292//===----------------------------------------------------------------------===//
@@ -348,21 +354,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
348354 // Outer Dims: 9x3x8 Inner Dims: 4x2
349355 ```
350356 }];
351- let arguments = (ins AnyRankedTensor:$source,
352- AnyRankedTensor:$dest,
353- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
354- DenseI64ArrayAttr:$inner_dims_pos,
355- Variadic<Index>:$inner_tiles,
356- DenseI64ArrayAttr:$static_inner_tiles);
357- let results = (outs AnyRankedTensor:$result);
358- let assemblyFormat = [{
359- $source
360- (`outer_dims_perm` `=` $outer_dims_perm^)?
361- `inner_dims_pos` `=` $inner_dims_pos
362- `inner_tiles` `=`
363- custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
364- `into` $dest attr-dict `:` type($source) `->` type($dest)
365- }];
357+ let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
358+ TensorOrMemRef<[AnyType]>:$dest,
359+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
360+ DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
361+ DenseI64ArrayAttr:$static_inner_tiles);
362+ let results = (outs Optional<AnyRankedTensor>:$result);
366363
367364 let builders = [
368365 OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -405,6 +402,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
405402 let hasCanonicalizeMethod = 1;
406403
407404 let hasFolder = 1;
405+
406+ let hasCustomAssemblyFormat = 1;
408407}
409408
410409#endif // LINALG_RELEAYOUT_OPS
0 commit comments