diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h index b969b60a66f16..b94a933b5c945 100644 --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h @@ -270,6 +270,12 @@ LogicalResult verifyAffineMapAsLayout(AffineMap m, ArrayRef shape, function_ref emitError); +// Return the strides and offsets that can be inferred from the given affine +// layout map given the map and a memref shape. +LogicalResult getAffineMapStridesAndOffset(AffineMap map, + ArrayRef shape, + SmallVectorImpl &strides, + int64_t &offset); } // namespace detail } // namespace mlir diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td index 6220d80264bdf..cf9697457f4d8 100644 --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -509,6 +509,23 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> { return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(), shape, emitError); }] + >, + + InterfaceMethod< + [{Return the strides (using ShapedType::kDynamic for the dynamic case) + that this layout corresponds to into `strides` and `offset` if such exist + and can be determined from a combination of the layout and the given + `shape`. If these strides cannot be inferred, return failure(). + The values of `strides` and `offset` are undefined on failure.}], + "::llvm::LogicalResult", "getStridesAndOffset", + (ins "::llvm::ArrayRef":$shape, + "::llvm::SmallVectorImpl&":$strides, + "int64_t&":$offset), + [{}], + [{ + return ::mlir::detail::getAffineMapStridesAndOffset( + $_attr.getAffineMap(), shape, strides, offset); + }] > ]; } diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 0169f4b38bbe0..854a24ab8605c 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -1003,7 +1003,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr< def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout", [DeclareAttrInterfaceMethods]> { + ["verifyLayout", "getStridesAndOffset"]>]> { let summary = "An Attribute representing a strided layout of a shaped type"; let description = [{ Syntax: diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp index 9b5235a6c5ceb..9e8ce4ca3a902 100644 --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -83,3 +83,138 @@ LogicalResult mlir::detail::verifyAffineMapAsLayout( return success(); } + +// Fallback cases for terminal dim/sym/cst that are not part of a binary op ( +// i.e. single term). Accumulate the AffineExpr into the existing one. +static void extractStridesFromTerm(AffineExpr e, + AffineExpr multiplicativeFactor, + MutableArrayRef strides, + AffineExpr &offset) { + if (auto dim = dyn_cast(e)) + strides[dim.getPosition()] = + strides[dim.getPosition()] + multiplicativeFactor; + else + offset = offset + e * multiplicativeFactor; +} + +/// Takes a single AffineExpr `e` and populates the `strides` array with the +/// strides expressions for each dim position. +/// The convention is that the strides for dimensions d0, .. dn appear in +/// order to make indexing intuitive into the result. +static LogicalResult extractStrides(AffineExpr e, + AffineExpr multiplicativeFactor, + MutableArrayRef strides, + AffineExpr &offset) { + auto bin = dyn_cast(e); + if (!bin) { + extractStridesFromTerm(e, multiplicativeFactor, strides, offset); + return success(); + } + + if (bin.getKind() == AffineExprKind::CeilDiv || + bin.getKind() == AffineExprKind::FloorDiv || + bin.getKind() == AffineExprKind::Mod) + return failure(); + + if (bin.getKind() == AffineExprKind::Mul) { + auto dim = dyn_cast(bin.getLHS()); + if (dim) { + strides[dim.getPosition()] = + strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; + return success(); + } + // LHS and RHS may both contain complex expressions of dims. Try one path + // and if it fails try the other. This is guaranteed to succeed because + // only one path may have a `dim`, otherwise this is not an AffineExpr in + // the first place. + if (bin.getLHS().isSymbolicOrConstant()) + return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), + strides, offset); + return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), + strides, offset); + } + + if (bin.getKind() == AffineExprKind::Add) { + auto res1 = + extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); + auto res2 = + extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); + return success(succeeded(res1) && succeeded(res2)); + } + + llvm_unreachable("unexpected binary operation"); +} + +/// A stride specification is a list of integer values that are either static +/// or dynamic (encoded with ShapedType::kDynamic). Strides encode +/// the distance in the number of elements between successive entries along a +/// particular dimension. +/// +/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a +/// non-contiguous memory region of `42` by `16` `f32` elements in which the +/// distance between two consecutive elements along the outer dimension is `1` +/// and the distance between two consecutive elements along the inner dimension +/// is `64`. +/// +/// The convention is that the strides for dimensions d0, .. dn appear in +/// order to make indexing intuitive into the result. +static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef shape, + SmallVectorImpl &strides, + AffineExpr &offset) { + if (m.getNumResults() != 1 && !m.isIdentity()) + return failure(); + + auto zero = getAffineConstantExpr(0, m.getContext()); + auto one = getAffineConstantExpr(1, m.getContext()); + offset = zero; + strides.assign(shape.size(), zero); + + // Canonical case for empty map. + if (m.isIdentity()) { + // 0-D corner case, offset is already 0. + if (shape.empty()) + return success(); + auto stridedExpr = makeCanonicalStridedLayoutExpr(shape, m.getContext()); + if (succeeded(extractStrides(stridedExpr, one, strides, offset))) + return success(); + assert(false && "unexpected failure: extract strides in canonical layout"); + } + + // Non-canonical case requires more work. + auto stridedExpr = + simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); + if (failed(extractStrides(stridedExpr, one, strides, offset))) { + offset = AffineExpr(); + strides.clear(); + return failure(); + } + + // Simplify results to allow folding to constants and simple checks. + unsigned numDims = m.getNumDims(); + unsigned numSymbols = m.getNumSymbols(); + offset = simplifyAffineExpr(offset, numDims, numSymbols); + for (auto &stride : strides) + stride = simplifyAffineExpr(stride, numDims, numSymbols); + + return success(); +} + +LogicalResult mlir::detail::getAffineMapStridesAndOffset( + AffineMap map, ArrayRef shape, SmallVectorImpl &strides, + int64_t &offset) { + AffineExpr offsetExpr; + SmallVector strideExprs; + if (failed(::getStridesAndOffset(map, shape, strideExprs, offsetExpr))) + return failure(); + if (auto cst = llvm::dyn_cast(offsetExpr)) + offset = cst.getValue(); + else + offset = ShapedType::kDynamic; + for (auto e : strideExprs) { + if (auto c = llvm::dyn_cast(e)) + strides.push_back(c.getValue()); + else + strides.push_back(ShapedType::kDynamic); + } + return success(); +} diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index e9af1f77a379e..617dcc222cd6e 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -258,6 +258,15 @@ LogicalResult StridedLayoutAttr::verifyLayout( return success(); } +LogicalResult +StridedLayoutAttr::getStridesAndOffset(ArrayRef, + SmallVectorImpl &strides, + int64_t &offset) const { + llvm::append_range(strides, getStrides()); + offset = getOffset(); + return success(); +} + //===----------------------------------------------------------------------===// // StringAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 3924d082f0628..d47e360e9dc13 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -715,150 +715,9 @@ MemRefType MemRefType::canonicalizeStridedLayout() { return MemRefType::Builder(*this).setLayout({}); } -// Fallback cases for terminal dim/sym/cst that are not part of a binary op ( -// i.e. single term). Accumulate the AffineExpr into the existing one. -static void extractStridesFromTerm(AffineExpr e, - AffineExpr multiplicativeFactor, - MutableArrayRef strides, - AffineExpr &offset) { - if (auto dim = dyn_cast(e)) - strides[dim.getPosition()] = - strides[dim.getPosition()] + multiplicativeFactor; - else - offset = offset + e * multiplicativeFactor; -} - -/// Takes a single AffineExpr `e` and populates the `strides` array with the -/// strides expressions for each dim position. -/// The convention is that the strides for dimensions d0, .. dn appear in -/// order to make indexing intuitive into the result. -static LogicalResult extractStrides(AffineExpr e, - AffineExpr multiplicativeFactor, - MutableArrayRef strides, - AffineExpr &offset) { - auto bin = dyn_cast(e); - if (!bin) { - extractStridesFromTerm(e, multiplicativeFactor, strides, offset); - return success(); - } - - if (bin.getKind() == AffineExprKind::CeilDiv || - bin.getKind() == AffineExprKind::FloorDiv || - bin.getKind() == AffineExprKind::Mod) - return failure(); - - if (bin.getKind() == AffineExprKind::Mul) { - auto dim = dyn_cast(bin.getLHS()); - if (dim) { - strides[dim.getPosition()] = - strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; - return success(); - } - // LHS and RHS may both contain complex expressions of dims. Try one path - // and if it fails try the other. This is guaranteed to succeed because - // only one path may have a `dim`, otherwise this is not an AffineExpr in - // the first place. - if (bin.getLHS().isSymbolicOrConstant()) - return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), - strides, offset); - return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), - strides, offset); - } - - if (bin.getKind() == AffineExprKind::Add) { - auto res1 = - extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); - auto res2 = - extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); - return success(succeeded(res1) && succeeded(res2)); - } - - llvm_unreachable("unexpected binary operation"); -} - -/// A stride specification is a list of integer values that are either static -/// or dynamic (encoded with ShapedType::kDynamic). Strides encode -/// the distance in the number of elements between successive entries along a -/// particular dimension. -/// -/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a -/// non-contiguous memory region of `42` by `16` `f32` elements in which the -/// distance between two consecutive elements along the outer dimension is `1` -/// and the distance between two consecutive elements along the inner dimension -/// is `64`. -/// -/// The convention is that the strides for dimensions d0, .. dn appear in -/// order to make indexing intuitive into the result. -static LogicalResult getStridesAndOffset(MemRefType t, - SmallVectorImpl &strides, - AffineExpr &offset) { - AffineMap m = t.getLayout().getAffineMap(); - - if (m.getNumResults() != 1 && !m.isIdentity()) - return failure(); - - auto zero = getAffineConstantExpr(0, t.getContext()); - auto one = getAffineConstantExpr(1, t.getContext()); - offset = zero; - strides.assign(t.getRank(), zero); - - // Canonical case for empty map. - if (m.isIdentity()) { - // 0-D corner case, offset is already 0. - if (t.getRank() == 0) - return success(); - auto stridedExpr = - makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); - if (succeeded(extractStrides(stridedExpr, one, strides, offset))) - return success(); - assert(false && "unexpected failure: extract strides in canonical layout"); - } - - // Non-canonical case requires more work. - auto stridedExpr = - simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); - if (failed(extractStrides(stridedExpr, one, strides, offset))) { - offset = AffineExpr(); - strides.clear(); - return failure(); - } - - // Simplify results to allow folding to constants and simple checks. - unsigned numDims = m.getNumDims(); - unsigned numSymbols = m.getNumSymbols(); - offset = simplifyAffineExpr(offset, numDims, numSymbols); - for (auto &stride : strides) - stride = simplifyAffineExpr(stride, numDims, numSymbols); - - return success(); -} - LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl &strides, int64_t &offset) { - // Happy path: the type uses the strided layout directly. - if (auto strided = llvm::dyn_cast(getLayout())) { - llvm::append_range(strides, strided.getStrides()); - offset = strided.getOffset(); - return success(); - } - - // Otherwise, defer to the affine fallback as layouts are supposed to be - // convertible to affine maps. - AffineExpr offsetExpr; - SmallVector strideExprs; - if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr))) - return failure(); - if (auto cst = llvm::dyn_cast(offsetExpr)) - offset = cst.getValue(); - else - offset = ShapedType::kDynamic; - for (auto e : strideExprs) { - if (auto c = llvm::dyn_cast(e)) - strides.push_back(c.getValue()); - else - strides.push_back(ShapedType::kDynamic); - } - return success(); + return getLayout().getStridesAndOffset(getShape(), strides, offset); } std::pair, int64_t> MemRefType::getStridesAndOffset() {