Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 139 additions & 117 deletions mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
Expand Down Expand Up @@ -50,28 +51,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
return arith::MulFOp::create(builder, loc, xConvert, yConvert);
}

// Delinearizes the given composite `index` by the basis specified in `factors`.
static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
ArrayRef<int64_t> factors) {
assert(!factors.empty() && "empty factor list");
SmallVector<Value> basis;
for (int64_t f : factors)
basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f)));
FailureOr<SmallVector<Value>> multiIndex =
affine::delinearizeIndex(b, loc, index, basis);
assert(!failed(multiIndex) && "Failed to linearize img2col index");
return *multiIndex;
// Builds the affine expression `stride * oIndex + fIndex`, i.e. the convolved
// input index for a given output iterator (oIndex) and filter iterator
// (fIndex). The two iterators are bound, in that order, either to the first
// two symbols (`useSymbols == true`, the default) or to the first two
// dimensions (`useSymbols == false`) of the builder's context.
static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
                                   bool useSymbols = true) {
  auto *ctx = b.getContext();
  AffineExpr outputIter, filterIter;
  if (!useSymbols) {
    bindDims(ctx, outputIter, filterIter);
  } else {
    bindSymbols(ctx, outputIter, filterIter);
  }
  AffineExpr stridedOutput = stride * outputIter;
  return stridedOutput + filterIter;
}

// Given indices corresponding to iterators in the output (oIndex) and filter
// (fIndex) for a convolution, compute the convolved index for the
// input as `oIndex * stride + fIndex`.
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
Value fIndex, int64_t stride) {
AffineExpr oExpr, fExpr;
bindSymbols(b.getContext(), oExpr, fExpr);
AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
// Stores the affine expressions to map the iteration space of the im2col matrix
// to the corresponding indices of the output and filter matrices.
//
// Each member is an expression over the im2col iteration dimensions; callers
// in this file fill them in from the results of `delinearize`.
struct Im2ColToOperandsExprs {
// Filter index along the height (y) axis.
AffineExpr fhIndex;
// Filter index along the width (x) axis.
AffineExpr fwIndex;
// Channel index; forwarded unchanged as the input channel index.
AffineExpr icIndex;
// Output index along the height (y) axis.
AffineExpr ohIndex;
// Output index along the width (x) axis.
AffineExpr owIndex;
};

// Stores the affine expressions to map the iteration space of the im2col matrix
// to the input matrix indices.
//
// NOTE: this aggregate is brace-initialized positionally as
// {bIndex, hIndex, wIndex, cIndex}; keep the member order in sync with that.
struct Im2ColToInputDimsExprs {
// Batch index of the input.
AffineExpr bIndex;
// Input index along the height axis (`stride_h * oh + fh`).
AffineExpr hIndex;
// Input index along the width axis (`stride_w * ow + fw`).
AffineExpr wIndex;
// Channel index of the input.
AffineExpr cIndex;
};

/// Construct the affine expressions that map the indices of the im2col matrix
/// to the corresponding input tensor indices for a 2D convolution with the
/// provided strides.
///
/// Each element of the im2col matrix corresponds to a pair of
/// (out_element, filter_element). The convolved expressions
/// `stride * out + filter` are first built over two placeholder dimensions and
/// then composed with the maps that take the im2col iteration space to the
/// (out, filter) pairs, yielding expressions directly over the im2col
/// iteration space.
///
/// @param exprs    Affine expressions for output and filter indices.
/// @param strides  [height, width] stride values for the convolution; must
///                 contain at least two entries.
/// @param rewriter Pattern rewriter.
/// @return         Affine expressions mapping im2col matrix indices to input
///                 offsets.
static Im2ColToInputDimsExprs
getIm2ColInputExpressions(Im2ColToOperandsExprs exprs,
                          ArrayRef<int64_t> strides, RewriterBase &rewriter) {
  // Both strides[0] and strides[1] are read below; fail loudly instead of
  // reading past the end of the array.
  assert(strides.size() >= 2 && "expected [height, width] strides");

  // Maps the iteration space of the im2col matrix to (output_y, filter_y).
  auto hIndicesMap = AffineMap::inferFromExprList(
      {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
  // Maps the iteration space of the im2col matrix to (output_x, filter_x).
  auto wIndicesMap = AffineMap::inferFromExprList(
      {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];

  // Compute the input indexing expressions, which map the indices of the
  // im2col matrix to the original input offsets. First build the expressions
  // that compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs
  // (using dimensions rather than symbols so they can be composed), then
  // compose them with the maps above.
  auto bIndexExpr = rewriter.getAffineDimExpr(0U);
  auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
                                     /*useSymbols=*/false);
  hIndexExpr = hIndexExpr.compose(hIndicesMap);
  auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
                                     /*useSymbols=*/false);
  wIndexExpr = wIndexExpr.compose(wIndicesMap);
  // The channel index needs no convolution arithmetic; pass it through.
  auto cIndexExpr = exprs.icIndex;
  return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
}

FailureOr<std::pair<Operation *, Operation *>>
Expand Down Expand Up @@ -135,44 +179,37 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

// Given an index of the im2col matrix, retrieve the corresponding indices of
// the output and filter matrices
auto mIndicesExprs =
delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
ArrayRef<int64_t>{fw * ic, ic, 1});
Im2ColToOperandsExprs i2cToOperExprs;
i2cToOperExprs.fhIndex = kIndicesExprs[0];
i2cToOperExprs.fwIndex = kIndicesExprs[1];
i2cToOperExprs.icIndex = kIndicesExprs[2];
i2cToOperExprs.ohIndex = mIndicesExprs[0];
i2cToOperExprs.owIndex = mIndicesExprs[1];

// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
rewriter);
auto inMap =
AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
inExprs.wIndex, inExprs.cIndex}},
rewriter.getContext())[0];

SmallVector<AffineMap> img2colIndexingMaps = {
AffineMap::getMultiDimIdentityMap(nloops, context)};
inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
/*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
/*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
// Get the iterators named based on the matmul (batch, m, k).
Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);

// Recover the original iteration indices from the problem/input sizes.
SmallVector<Value> mIndices = unrollIndex(
nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
auto ohIndex = mIndices[0];
auto owIndex = mIndices[1];

SmallVector<Value> kIndices = unrollIndex(
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
auto fhIndex = kIndices[0];
auto fwIndex = kIndices[1];
auto icIndex = kIndices[2];

// Extract the input element corresponding to the expanded indices.
Value hIndex =
getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
convOp.getStrides().getValues<int64_t>()[0]);
Value wIndex =
getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
convOp.getStrides().getValues<int64_t>()[1]);

// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
extractionIndices);
linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});

// Because the filter does not share the same batch dimension,
Expand Down Expand Up @@ -421,44 +458,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);

SmallVector<AffineMap, 4> img2colIndexingMaps = {
AffineMap::getMultiDimIdentityMap(nloops, context)};
// Recover the original iteration indices from the problem/input sizes:
// given an index of the im2col matrix, retrieve the corresponding indices of
// the output and filter matrices
auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
ArrayRef<int64_t>{fh * fw, fw, 1});
auto mIndicesExprs =
delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
Im2ColToOperandsExprs i2cToOperExprs;
i2cToOperExprs.icIndex = kIndicesExprs[0];
i2cToOperExprs.fhIndex = kIndicesExprs[1];
i2cToOperExprs.fwIndex = kIndicesExprs[2];
i2cToOperExprs.ohIndex = mIndicesExprs[0];
i2cToOperExprs.owIndex = mIndicesExprs[1];
Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
rewriter);
auto inMap =
AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex,
inExprs.hIndex, inExprs.wIndex}},
rewriter.getContext())[0];
// im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
SmallVector<AffineMap> img2colIndexingMaps = {
inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
/*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
/*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
// Get the iterators named based on the matmul (batch, m, k).
Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);

// Recover the original iteration indices from the problem/input sizes.
SmallVector<Value> kIndices = unrollIndex(
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
auto icIndex = kIndices[0];
auto fhIndex = kIndices[1];
auto fwIndex = kIndices[2];

SmallVector<Value> nIndices = unrollIndex(
nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
auto ohIndex = nIndices[0];
auto owIndex = nIndices[1];

// Extract the input element corresponding to the expanded indices.
Value hIndex =
getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
convOp.getStrides().getValues<int64_t>()[0]);
Value wIndex =
getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
convOp.getStrides().getValues<int64_t>()[1]);

// im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
extractionIndices);
linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});

// Because the filter does not share the same batch dimension,
Expand Down Expand Up @@ -545,6 +574,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
Value reshapedOutput = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedOutputType, output, outputReassocIndices);

// Shape of the Toeplitz matrix produced by Im2col.
SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
inputType.getElementType());
Expand All @@ -556,44 +586,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

// Given an index of the im2col matrix, retrieve the corresponding indices of
// the output and filter matrices
auto mIndicesExprs =
delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
ArrayRef<int64_t>{fw * ic, ic, 1});
Im2ColToOperandsExprs i2cToOperExprs;
i2cToOperExprs.fhIndex = kIndicesExprs[0];
i2cToOperExprs.fwIndex = kIndicesExprs[1];
i2cToOperExprs.icIndex = kIndicesExprs[2];
i2cToOperExprs.ohIndex = mIndicesExprs[0];
i2cToOperExprs.owIndex = mIndicesExprs[1];

// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
rewriter);
auto inMap =
AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
inExprs.wIndex, inExprs.cIndex}},
rewriter.getContext())[0];
SmallVector<AffineMap> img2colIndexingMaps = {
AffineMap::getMultiDimIdentityMap(nloops, context)};
inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
/*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
/*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
// Get the iterators named based on the matmul (batch, m, k).
Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);

// Recover the original iteration indices from the problem/input sizes.
SmallVector<Value> mIndices = unrollIndex(
nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
auto ohIndex = mIndices[0];
auto owIndex = mIndices[1];

SmallVector<Value> kIndices = unrollIndex(
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
auto fhIndex = kIndices[0];
auto fwIndex = kIndices[1];
auto icIndex = kIndices[2];

// Extract the input element corresponding to the expanded indices.
Value hIndex =
getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
convOp.getStrides().getValues<int64_t>()[0]);
Value wIndex =
getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
convOp.getStrides().getValues<int64_t>()[1]);

// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
extractionIndices);
linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});

// Because we didn't transpose the filters we don't actually have a batched
Expand Down
Loading