Skip to content

Commit 57a4727

Browse files
[mlir][linalg] Produce canonical linalg.generic for im2col (llvm#134675)
Before this patch, the Img2Col transform produced a non-canonical linalg.generic whose input tensor was not reported in the inputs of the operation: instead, it was accessed manually from inside the op body, after an internal calculation of the access offsets. This patch modifies the Im2Col rewrite to produce a canonical linalg.generic whose input is correctly reported in its 'ins()', whose access offsets are computed through an indexing map, and whose body contains only a 'linalg.yield' op. Signed-off-by: Fabrizio Indirli <[email protected]> Co-authored-by: Georgios Pinitas <[email protected]>
1 parent c767ee1 commit 57a4727

File tree

2 files changed

+185
-197
lines changed

2 files changed

+185
-197
lines changed

mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp

Lines changed: 139 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1313
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1414
#include "mlir/Dialect/Tensor/IR/Tensor.h"
15+
#include "mlir/Dialect/Utils/IndexingUtils.h"
1516
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1617
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1718
#include "mlir/IR/AffineExpr.h"
@@ -50,28 +51,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
5051
return arith::MulFOp::create(builder, loc, xConvert, yConvert);
5152
}
5253

53-
// Delinearizes the given composite `index` by the basis specified in `factors`.
54-
static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
55-
ArrayRef<int64_t> factors) {
56-
assert(!factors.empty() && "empty factor list");
57-
SmallVector<Value> basis;
58-
for (int64_t f : factors)
59-
basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f)));
60-
FailureOr<SmallVector<Value>> multiIndex =
61-
affine::delinearizeIndex(b, loc, index, basis);
62-
assert(!failed(multiIndex) && "Failed to linearize img2col index");
63-
return *multiIndex;
54+
// Generate the affine expression to compute the convolved index
55+
// for the input as `oIndex * stride + fIndex`,
56+
// where oIndex: output iterator; fIndex: filter iterator.
57+
static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
58+
bool useSymbols = true) {
59+
AffineExpr oExpr, fExpr;
60+
if (useSymbols)
61+
bindSymbols(b.getContext(), oExpr, fExpr);
62+
else
63+
bindDims(b.getContext(), oExpr, fExpr);
64+
return AffineExpr(stride * oExpr + fExpr);
6465
}
6566

66-
// Given indices corresponding to iterators in the output (oIndex) and filter
67-
// (fIndex) for a convolution, compute the convolved index for the
68-
// input as `oIndex * stride + fIndex`.
69-
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
70-
Value fIndex, int64_t stride) {
71-
AffineExpr oExpr, fExpr;
72-
bindSymbols(b.getContext(), oExpr, fExpr);
73-
AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
74-
return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
67+
// Stores the affine expressions to map the iteration space of the im2col matrix
68+
// to the corresponding indices of the output and filter matrices
69+
struct Im2ColToOperandsExprs {
70+
AffineExpr fhIndex;
71+
AffineExpr fwIndex;
72+
AffineExpr icIndex;
73+
AffineExpr ohIndex;
74+
AffineExpr owIndex;
75+
};
76+
77+
// Stores the affine expressions to map the iteration space of the im2col matrix
78+
// to the input matrix indices
79+
struct Im2ColToInputDimsExprs {
80+
AffineExpr bIndex;
81+
AffineExpr hIndex;
82+
AffineExpr wIndex;
83+
AffineExpr cIndex;
84+
};
85+
86+
/// Construct the affine expressions that map the indices of the im2col matrix
87+
/// to the corresponding input tensor indices for a 2D convolution with the
88+
/// provided strides.
89+
///
90+
/// @param exprs Affine expressions for output and filter indices.
91+
/// @param strides [height, width] stride values for the convolution.
92+
/// @param rewriter Pattern rewriter.
93+
/// @return Affine expressions mapping im2col matrix indices to input
94+
/// offsets.
95+
static Im2ColToInputDimsExprs
96+
getIm2ColInputExpressions(Im2ColToOperandsExprs exprs,
97+
ArrayRef<int64_t> strides, RewriterBase &rewriter) {
98+
// maps the iteration space of the im2col matrix to (output_y, filter_y)
99+
auto hIndicesMap = AffineMap::inferFromExprList(
100+
{ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
101+
// maps the iteration space of the im2col matrix to (output_x, filter_x)
102+
auto wIndicesMap = AffineMap::inferFromExprList(
103+
{ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
104+
// Compute the input indexing map, to map the indices of the im2col matrix to
105+
// the original input offsets. Each element of the im2col matrix corresponds
106+
// to a pair of (out_element, filter_element). First, we build the expressions
107+
// to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
108+
// then we compose them with the maps that map the im2col matrix elements to
109+
// the (out_element, filter_element) pairs.
110+
auto bIndexExpr = rewriter.getAffineDimExpr(0U);
111+
auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
112+
/*useSymbols*/ false);
113+
hIndexExpr = hIndexExpr.compose(hIndicesMap);
114+
auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
115+
/*useSymbols*/ false);
116+
wIndexExpr = wIndexExpr.compose(wIndicesMap);
117+
auto cIndexExpr = exprs.icIndex;
118+
return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
75119
}
76120

77121
FailureOr<std::pair<Operation *, Operation *>>
@@ -135,44 +179,37 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
135179
auto reduction = utils::IteratorType::reduction;
136180
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
137181

182+
// Given an index of the im2col matrix, retrieve the corresponding indices of
183+
// the output and filter matrices
184+
auto mIndicesExprs =
185+
delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
186+
auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
187+
ArrayRef<int64_t>{fw * ic, ic, 1});
188+
Im2ColToOperandsExprs i2cToOperExprs;
189+
i2cToOperExprs.fhIndex = kIndicesExprs[0];
190+
i2cToOperExprs.fwIndex = kIndicesExprs[1];
191+
i2cToOperExprs.icIndex = kIndicesExprs[2];
192+
i2cToOperExprs.ohIndex = mIndicesExprs[0];
193+
i2cToOperExprs.owIndex = mIndicesExprs[1];
194+
195+
// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
196+
Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
197+
i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
198+
rewriter);
199+
auto inMap =
200+
AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
201+
inExprs.wIndex, inExprs.cIndex}},
202+
rewriter.getContext())[0];
203+
138204
SmallVector<AffineMap> img2colIndexingMaps = {
139-
AffineMap::getMultiDimIdentityMap(nloops, context)};
205+
inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
140206

141207
auto img2ColTensor = linalg::GenericOp::create(
142208
rewriter, loc, colTensor.getType(),
143-
/*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
209+
/*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
144210
img2colIterators,
145211
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
146-
// Get the iterators named based on the matmul (batch, m, k).
147-
Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
148-
Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
149-
Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
150-
151-
// Recover the original iteration indices from the problem/input sizes.
152-
SmallVector<Value> mIndices = unrollIndex(
153-
nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
154-
auto ohIndex = mIndices[0];
155-
auto owIndex = mIndices[1];
156-
157-
SmallVector<Value> kIndices = unrollIndex(
158-
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
159-
auto fhIndex = kIndices[0];
160-
auto fwIndex = kIndices[1];
161-
auto icIndex = kIndices[2];
162-
163-
// Extract the input element corresponding to the expanded indices.
164-
Value hIndex =
165-
getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
166-
convOp.getStrides().getValues<int64_t>()[0]);
167-
Value wIndex =
168-
getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
169-
convOp.getStrides().getValues<int64_t>()[1]);
170-
171-
// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
172-
SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
173-
Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
174-
extractionIndices);
175-
linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
212+
linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
176213
});
177214

178215
// Because the filter does not share the same batch dimension,
@@ -421,44 +458,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
421458
auto reduction = utils::IteratorType::reduction;
422459
SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
423460

424-
SmallVector<AffineMap, 4> img2colIndexingMaps = {
425-
AffineMap::getMultiDimIdentityMap(nloops, context)};
461+
// Recover the original iteration indices from the problem/input sizes:
462+
// given an index of the im2col matrix, retrieve the corresponding indices of
463+
// the output and filter matrices
464+
auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
465+
ArrayRef<int64_t>{fh * fw, fw, 1});
466+
auto mIndicesExprs =
467+
delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
468+
Im2ColToOperandsExprs i2cToOperExprs;
469+
i2cToOperExprs.icIndex = kIndicesExprs[0];
470+
i2cToOperExprs.fhIndex = kIndicesExprs[1];
471+
i2cToOperExprs.fwIndex = kIndicesExprs[2];
472+
i2cToOperExprs.ohIndex = mIndicesExprs[0];
473+
i2cToOperExprs.owIndex = mIndicesExprs[1];
474+
Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
475+
i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
476+
rewriter);
477+
auto inMap =
478+
AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex,
479+
inExprs.hIndex, inExprs.wIndex}},
480+
rewriter.getContext())[0];
481+
// im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
482+
SmallVector<AffineMap> img2colIndexingMaps = {
483+
inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
426484

427485
auto img2ColTensor = linalg::GenericOp::create(
428486
rewriter, loc, colTensor.getType(),
429-
/*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
487+
/*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
430488
img2colIterators,
431489
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
432-
// Get the iterators named based on the matmul (batch, m, k).
433-
Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
434-
Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
435-
Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
436-
437-
// Recover the original iteration indices from the problem/input sizes.
438-
SmallVector<Value> kIndices = unrollIndex(
439-
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
440-
auto icIndex = kIndices[0];
441-
auto fhIndex = kIndices[1];
442-
auto fwIndex = kIndices[2];
443-
444-
SmallVector<Value> nIndices = unrollIndex(
445-
nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
446-
auto ohIndex = nIndices[0];
447-
auto owIndex = nIndices[1];
448-
449-
// Extract the input element corresponding to the expanded indices.
450-
Value hIndex =
451-
getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
452-
convOp.getStrides().getValues<int64_t>()[0]);
453-
Value wIndex =
454-
getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
455-
convOp.getStrides().getValues<int64_t>()[1]);
456-
457-
// im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
458-
SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
459-
Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
460-
extractionIndices);
461-
linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
490+
linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
462491
});
463492

464493
// Because the filter does not share the same batch dimension,
@@ -545,6 +574,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
545574
Value reshapedOutput = tensor::CollapseShapeOp::create(
546575
rewriter, loc, reshapedOutputType, output, outputReassocIndices);
547576

577+
// Shape of the Toeplitz matrix produced by Im2col.
548578
SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
549579
Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
550580
inputType.getElementType());
@@ -556,44 +586,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
556586
auto reduction = utils::IteratorType::reduction;
557587
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
558588

589+
// Given an index of the im2col matrix, retrieve the corresponding indices of
590+
// the output and filter matrices
591+
auto mIndicesExprs =
592+
delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
593+
auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
594+
ArrayRef<int64_t>{fw * ic, ic, 1});
595+
Im2ColToOperandsExprs i2cToOperExprs;
596+
i2cToOperExprs.fhIndex = kIndicesExprs[0];
597+
i2cToOperExprs.fwIndex = kIndicesExprs[1];
598+
i2cToOperExprs.icIndex = kIndicesExprs[2];
599+
i2cToOperExprs.ohIndex = mIndicesExprs[0];
600+
i2cToOperExprs.owIndex = mIndicesExprs[1];
601+
602+
// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
603+
Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
604+
i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
605+
rewriter);
606+
auto inMap =
607+
AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
608+
inExprs.wIndex, inExprs.cIndex}},
609+
rewriter.getContext())[0];
559610
SmallVector<AffineMap> img2colIndexingMaps = {
560-
AffineMap::getMultiDimIdentityMap(nloops, context)};
611+
inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
561612

562613
auto img2ColTensor = linalg::GenericOp::create(
563614
rewriter, loc, colTensor.getType(),
564-
/*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
615+
/*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
565616
img2colIterators,
566617
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
567-
// Get the iterators named based on the matmul (batch, m, k).
568-
Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
569-
Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
570-
Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
571-
572-
// Recover the original iteration indices from the problem/input sizes.
573-
SmallVector<Value> mIndices = unrollIndex(
574-
nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
575-
auto ohIndex = mIndices[0];
576-
auto owIndex = mIndices[1];
577-
578-
SmallVector<Value> kIndices = unrollIndex(
579-
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
580-
auto fhIndex = kIndices[0];
581-
auto fwIndex = kIndices[1];
582-
auto icIndex = kIndices[2];
583-
584-
// Extract the input element corresponding to the expanded indices.
585-
Value hIndex =
586-
getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
587-
convOp.getStrides().getValues<int64_t>()[0]);
588-
Value wIndex =
589-
getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
590-
convOp.getStrides().getValues<int64_t>()[1]);
591-
592-
// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
593-
SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
594-
Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
595-
extractionIndices);
596-
linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
618+
linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
597619
});
598620

599621
// Because we didn't transpose the filters we don't actually have a batched

0 commit comments

Comments
 (0)