77//===----------------------------------------------------------------------===//
88//
99// This file defines Pack + Unpack Ops that have been moved from the Tensor
10- // dialect. As such, these are defined as memory-effect-free and only accept
11- // "tensors" as inputs.
12- //
13- // TODO: Once a good motivating example is identified, relax these
14- // restrictions.
10+ // dialect.
1511//
1612//===----------------------------------------------------------------------===//
1713
@@ -30,24 +26,27 @@ include "mlir/IR/OpAsmInterface.td"
3026// RelayoutOp
3127//===----------------------------------------------------------------------===//
3228
33- class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
34- Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
35- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
36- DestinationStyleOpInterface, LinalgRelayoutOpInterface,
37- ConditionallySpeculatable, NoMemoryEffect,
38- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
29+ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
30+ : Op<Linalg_Dialect, mnemonic,
31+ !listconcat(
32+ traits, [DeclareOpInterfaceMethods<
33+ OpAsmOpInterface, ["getAsmResultNames"]>,
34+ DestinationStyleOpInterface, LinalgRelayoutOpInterface,
35+ ConditionallySpeculatable,
36+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
37+ DeclareOpInterfaceMethods<
38+ ReifyRankedShapedTypeOpInterface, [
3939 "reifyResultShapes"]>,
40- TypesMatchWith<"result type matches type of dest",
41- "dest", "result",
42- "$_self">])> {
40+ OptionalTypesMatchWith<"result type matches type of dest",
41+ "dest", "result", "$_self">])> {
4342
4443 code commonExtraClassDeclaration = [{
4544 size_t getSourceRank() { return getSourceType().getRank(); };
4645 size_t getDestRank() { return getDestType().getRank(); };
47- RankedTensorType getSourceType() {
48- return ::llvm::cast<RankedTensorType >(getSource().getType()); };
49- RankedTensorType getDestType() {
50- return ::llvm::cast<RankedTensorType >(getDest().getType()); };
46+ ShapedType getSourceType() {
47+ return ::llvm::cast<ShapedType >(getSource().getType()); };
48+ ShapedType getDestType() {
49+ return ::llvm::cast<ShapedType >(getDest().getType()); };
5150
5251 MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
5352
@@ -192,23 +191,12 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
192191 // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
193192 ```
194193 }];
195- let arguments = (ins AnyRankedTensor:$source,
196- AnyRankedTensor:$dest,
197- Optional<AnyType>:$padding_value,
198- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
199- DenseI64ArrayAttr:$inner_dims_pos,
200- Variadic<Index>:$inner_tiles,
201- DenseI64ArrayAttr:$static_inner_tiles);
202- let results = (outs AnyRankedTensor:$result);
203- let assemblyFormat = [{
204- $source
205- (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
206- (`outer_dims_perm` `=` $outer_dims_perm^)?
207- `inner_dims_pos` `=` $inner_dims_pos
208- `inner_tiles` `=`
209- custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
210- `into` $dest attr-dict `:` type($source) `->` type($dest)
211- }];
194+ let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
195+ TensorOrMemRef<[AnyType]>:$dest, Optional<AnyType>:$padding_value,
196+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
197+ DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
198+ DenseI64ArrayAttr:$static_inner_tiles);
199+ let results = (outs Optional<AnyRankedTensor>:$result);
212200
213201 let builders = [
214202 OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -218,7 +206,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
218206 CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
219207 ];
220208
221- let extraClassDeclaration = commonExtraClassDeclaration # [{
209+ let extraClassDeclaration = commonExtraClassDeclaration# [{
222210 // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
223211 // This is a static method to allow getting the shape of the destination
224212 // expected while creating a `pack` op.
@@ -230,7 +218,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
230218 // Method to get the `RankedTensorType` of the result based on the inner
231219 // tiles, position of the inner tiles (innerDimsPos) and interchange vector
232220 // of outer loops (outerDimsPerm).
233- static RankedTensorType inferPackedType(RankedTensorType sourceType,
221+ static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
222+ ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
223+ ArrayRef<int64_t> outerDimsPerm = {});
224+
225+ // Method to get the `MemRefType` of the result based on the inner
226+ // tiles, position of the inner tiles (innerDimsPos) and interchange vector
227+ // of outer loops (outerDimsPerm).
228+ static MemRefType inferPackedMemRefType(MemRefType sourceType,
229+ ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
230+ ArrayRef<int64_t> outerDimsPerm = {});
231+
232+ // Returns the shape of the packed type. It is a shared helper helps type inference methods in a way that ensures that they agree on which dimensions are dynamic.
233+ static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
234234 ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
235235 ArrayRef<int64_t> outerDimsPerm = {});
236236
@@ -282,6 +282,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
282282 let hasCanonicalizeMethod = 1;
283283
284284 let hasFolder = 1;
285+
286+ let hasCustomAssemblyFormat = 1;
285287}
286288
287289//===----------------------------------------------------------------------===//
@@ -349,21 +351,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
349351 // Outer Dims: 9x3x8 Inner Dims: 4x2
350352 ```
351353 }];
352- let arguments = (ins AnyRankedTensor:$source,
353- AnyRankedTensor:$dest,
354- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
355- DenseI64ArrayAttr:$inner_dims_pos,
356- Variadic<Index>:$inner_tiles,
357- DenseI64ArrayAttr:$static_inner_tiles);
358- let results = (outs AnyRankedTensor:$result);
359- let assemblyFormat = [{
360- $source
361- (`outer_dims_perm` `=` $outer_dims_perm^)?
362- `inner_dims_pos` `=` $inner_dims_pos
363- `inner_tiles` `=`
364- custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
365- `into` $dest attr-dict `:` type($source) `->` type($dest)
366- }];
354+ let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
355+ TensorOrMemRef<[AnyType]>:$dest,
356+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
357+ DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
358+ DenseI64ArrayAttr:$static_inner_tiles);
359+ let results = (outs Optional<AnyRankedTensor>:$result);
367360
368361 let builders = [
369362 OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -406,6 +399,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
406399 let hasCanonicalizeMethod = 1;
407400
408401 let hasFolder = 1;
402+
403+ let hasCustomAssemblyFormat = 1;
409404}
410405
411406#endif // LINALG_RELEAYOUT_OPS
0 commit comments