Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ LogicalResult
verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
function_ref<InFlightDiagnostic()> emitError);

// Return the strides and offsets that can be inferred from the given affine
// layout map given the map and a memref shape.
LogicalResult getAffineMapStridesAndOffset(AffineMap map,
ArrayRef<int64_t> shape,
SmallVectorImpl<int64_t> &strides,
int64_t &offset);
} // namespace detail

} // namespace mlir
Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,23 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(),
shape, emitError);
}]
>,

InterfaceMethod<
[{Return the strides (using ShapedType::kDynamic for the dynamic case)
that this layout corresponds to into `strides` and `offset` if such exist
and can be detirmined from a combination of the layout and the given
`shape`. If these strides cannot be inferred, return failure().
The values of `strides` and `offset` are undefined on failure.}],
"::llvm::LogicalResult", "getStridesAndOffset",
(ins "::llvm::ArrayRef<int64_t>":$shape,
"::llvm::SmallVectorImpl<int64_t>&":$strides,
"int64_t&":$offset),
[{}],
[{
return ::mlir::detail::getAffineMapStridesAndOffset(
$_attr.getAffineMap(), shape, strides, offset);
}]
>
];
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<

def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
[DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface,
["verifyLayout"]>]> {
["verifyLayout", "getStridesAndOffset"]>]> {
let summary = "An Attribute representing a strided layout of a shaped type";
let description = [{
Syntax:
Expand Down
135 changes: 135 additions & 0 deletions mlir/lib/IR/BuiltinAttributeInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,138 @@ LogicalResult mlir::detail::verifyAffineMapAsLayout(

return success();
}

// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
// i.e. single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
AffineExpr multiplicativeFactor,
MutableArrayRef<AffineExpr> strides,
AffineExpr &offset) {
if (auto dim = dyn_cast<AffineDimExpr>(e))
strides[dim.getPosition()] =
strides[dim.getPosition()] + multiplicativeFactor;
else
offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult extractStrides(AffineExpr e,
AffineExpr multiplicativeFactor,
MutableArrayRef<AffineExpr> strides,
AffineExpr &offset) {
auto bin = dyn_cast<AffineBinaryOpExpr>(e);
if (!bin) {
extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
return success();
}

if (bin.getKind() == AffineExprKind::CeilDiv ||
bin.getKind() == AffineExprKind::FloorDiv ||
bin.getKind() == AffineExprKind::Mod)
return failure();

if (bin.getKind() == AffineExprKind::Mul) {
auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
if (dim) {
strides[dim.getPosition()] =
strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
return success();
}
// LHS and RHS may both contain complex expressions of dims. Try one path
// and if it fails try the other. This is guaranteed to succeed because
// only one path may have a `dim`, otherwise this is not an AffineExpr in
// the first place.
if (bin.getLHS().isSymbolicOrConstant())
return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
strides, offset);
return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
strides, offset);
}

if (bin.getKind() == AffineExprKind::Add) {
auto res1 =
extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
auto res2 =
extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
return success(succeeded(res1) && succeeded(res2));
}

llvm_unreachable("unexpected binary operation");
}

/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
/// the distance in the number of elements between successive entries along a
/// particular dimension.
///
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `1`
/// and the distance between two consecutive elements along the inner dimension
/// is `64`.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef<int64_t> shape,
SmallVectorImpl<AffineExpr> &strides,
AffineExpr &offset) {
if (m.getNumResults() != 1 && !m.isIdentity())
return failure();

auto zero = getAffineConstantExpr(0, m.getContext());
auto one = getAffineConstantExpr(1, m.getContext());
offset = zero;
strides.assign(shape.size(), zero);

// Canonical case for empty map.
if (m.isIdentity()) {
// 0-D corner case, offset is already 0.
if (shape.empty())
return success();
auto stridedExpr = makeCanonicalStridedLayoutExpr(shape, m.getContext());
if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
return success();
assert(false && "unexpected failure: extract strides in canonical layout");
}

// Non-canonical case requires more work.
auto stridedExpr =
simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
if (failed(extractStrides(stridedExpr, one, strides, offset))) {
offset = AffineExpr();
strides.clear();
return failure();
}

// Simplify results to allow folding to constants and simple checks.
unsigned numDims = m.getNumDims();
unsigned numSymbols = m.getNumSymbols();
offset = simplifyAffineExpr(offset, numDims, numSymbols);
for (auto &stride : strides)
stride = simplifyAffineExpr(stride, numDims, numSymbols);

return success();
}

LogicalResult mlir::detail::getAffineMapStridesAndOffset(
AffineMap map, ArrayRef<int64_t> shape, SmallVectorImpl<int64_t> &strides,
int64_t &offset) {
AffineExpr offsetExpr;
SmallVector<AffineExpr, 4> strideExprs;
if (failed(::getStridesAndOffset(map, shape, strideExprs, offsetExpr)))
return failure();
if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
offset = cst.getValue();
else
offset = ShapedType::kDynamic;
for (auto e : strideExprs) {
if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
strides.push_back(c.getValue());
else
strides.push_back(ShapedType::kDynamic);
}
return success();
}
9 changes: 9 additions & 0 deletions mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,15 @@ LogicalResult StridedLayoutAttr::verifyLayout(
return success();
}

LogicalResult
StridedLayoutAttr::getStridesAndOffset(ArrayRef<int64_t>,
SmallVectorImpl<int64_t> &strides,
int64_t &offset) const {
llvm::append_range(strides, getStrides());
offset = getOffset();
return success();
}

//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//
Expand Down
143 changes: 1 addition & 142 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -715,150 +715,9 @@ MemRefType MemRefType::canonicalizeStridedLayout() {
return MemRefType::Builder(*this).setLayout({});
}

// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
// i.e. single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
AffineExpr multiplicativeFactor,
MutableArrayRef<AffineExpr> strides,
AffineExpr &offset) {
if (auto dim = dyn_cast<AffineDimExpr>(e))
strides[dim.getPosition()] =
strides[dim.getPosition()] + multiplicativeFactor;
else
offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult extractStrides(AffineExpr e,
AffineExpr multiplicativeFactor,
MutableArrayRef<AffineExpr> strides,
AffineExpr &offset) {
auto bin = dyn_cast<AffineBinaryOpExpr>(e);
if (!bin) {
extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
return success();
}

if (bin.getKind() == AffineExprKind::CeilDiv ||
bin.getKind() == AffineExprKind::FloorDiv ||
bin.getKind() == AffineExprKind::Mod)
return failure();

if (bin.getKind() == AffineExprKind::Mul) {
auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
if (dim) {
strides[dim.getPosition()] =
strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
return success();
}
// LHS and RHS may both contain complex expressions of dims. Try one path
// and if it fails try the other. This is guaranteed to succeed because
// only one path may have a `dim`, otherwise this is not an AffineExpr in
// the first place.
if (bin.getLHS().isSymbolicOrConstant())
return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
strides, offset);
return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
strides, offset);
}

if (bin.getKind() == AffineExprKind::Add) {
auto res1 =
extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
auto res2 =
extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
return success(succeeded(res1) && succeeded(res2));
}

llvm_unreachable("unexpected binary operation");
}

/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
/// the distance in the number of elements between successive entries along a
/// particular dimension.
///
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `1`
/// and the distance between two consecutive elements along the inner dimension
/// is `64`.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult getStridesAndOffset(MemRefType t,
SmallVectorImpl<AffineExpr> &strides,
AffineExpr &offset) {
AffineMap m = t.getLayout().getAffineMap();

if (m.getNumResults() != 1 && !m.isIdentity())
return failure();

auto zero = getAffineConstantExpr(0, t.getContext());
auto one = getAffineConstantExpr(1, t.getContext());
offset = zero;
strides.assign(t.getRank(), zero);

// Canonical case for empty map.
if (m.isIdentity()) {
// 0-D corner case, offset is already 0.
if (t.getRank() == 0)
return success();
auto stridedExpr =
makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
return success();
assert(false && "unexpected failure: extract strides in canonical layout");
}

// Non-canonical case requires more work.
auto stridedExpr =
simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
if (failed(extractStrides(stridedExpr, one, strides, offset))) {
offset = AffineExpr();
strides.clear();
return failure();
}

// Simplify results to allow folding to constants and simple checks.
unsigned numDims = m.getNumDims();
unsigned numSymbols = m.getNumSymbols();
offset = simplifyAffineExpr(offset, numDims, numSymbols);
for (auto &stride : strides)
stride = simplifyAffineExpr(stride, numDims, numSymbols);

return success();
}

LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
int64_t &offset) {
// Happy path: the type uses the strided layout directly.
if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout())) {
llvm::append_range(strides, strided.getStrides());
offset = strided.getOffset();
return success();
}

// Otherwise, defer to the affine fallback as layouts are supposed to be
// convertible to affine maps.
AffineExpr offsetExpr;
SmallVector<AffineExpr, 4> strideExprs;
if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr)))
return failure();
if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
offset = cst.getValue();
else
offset = ShapedType::kDynamic;
for (auto e : strideExprs) {
if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
strides.push_back(c.getValue());
else
strides.push_back(ShapedType::kDynamic);
}
return success();
return getLayout().getStridesAndOffset(getShape(), strides, offset);
}

std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() {
Expand Down