Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 47 additions & 52 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file defines Pack + Unpack Ops that have been moved from the Tensor
// dialect. As such, these are defined as memory-effect-free and only accept
// "tensors" as inputs.
//
// TODO: Once a good motivating example is identified, relax these
// restrictions.
// dialect.
//
//===----------------------------------------------------------------------===//

Expand All @@ -30,24 +26,27 @@ include "mlir/IR/OpAsmInterface.td"
// RelayoutOp
//===----------------------------------------------------------------------===//

class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
ConditionallySpeculatable, NoMemoryEffect,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
: Op<Linalg_Dialect, mnemonic,
!listconcat(
traits, [DeclareOpInterfaceMethods<
OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
ConditionallySpeculatable,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<
ReifyRankedShapedTypeOpInterface, [
"reifyResultShapes"]>,
TypesMatchWith<"result type matches type of dest",
"dest", "result",
"$_self">])> {
OptionalTypesMatchWith<"result type matches type of dest",
"dest", "result", "$_self">])> {

code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
RankedTensorType getSourceType() {
return ::llvm::cast<RankedTensorType>(getSource().getType()); };
RankedTensorType getDestType() {
return ::llvm::cast<RankedTensorType>(getDest().getType()); };
ShapedType getSourceType() {
return ::llvm::cast<ShapedType>(getSource().getType()); };
ShapedType getDestType() {
return ::llvm::cast<ShapedType>(getDest().getType()); };

MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }

Expand Down Expand Up @@ -192,23 +191,12 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $dest attr-dict `:` type($source) `->` type($dest)
}];
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
TensorOrMemRef<[AnyType]>:$dest, Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs Optional<AnyRankedTensor>:$result);

let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
Expand All @@ -218,7 +206,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
];

let extraClassDeclaration = commonExtraClassDeclaration # [{
let extraClassDeclaration = commonExtraClassDeclaration#[{
// Method to get the shape of the result as `SmallVector<OpFoldResult>`.
// This is a static method to allow getting the shape of the destination
// expected while creating a `pack` op.
Expand All @@ -230,7 +218,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static RankedTensorType inferPackedType(RankedTensorType sourceType,
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

// Method to get the `MemRefType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static MemRefType inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

// Returns the shape of the packed type. It is a shared helper helps type inference methods in a way that ensures that they agree on which dimensions are dynamic.
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

Expand Down Expand Up @@ -282,6 +282,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
let hasCanonicalizeMethod = 1;

let hasFolder = 1;

let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -349,21 +351,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
// Outer Dims: 9x3x8 Inner Dims: 4x2
```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $dest attr-dict `:` type($source) `->` type($dest)
}];
let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
TensorOrMemRef<[AnyType]>:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs Optional<AnyRankedTensor>:$result);

let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
Expand Down Expand Up @@ -406,6 +399,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
let hasCanonicalizeMethod = 1;

let hasFolder = 1;

let hasCustomAssemblyFormat = 1;
}

#endif // LINALG_RELEAYOUT_OPS
Loading