1212#include " mlir/Dialect/Linalg/IR/Linalg.h"
1313#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
1414#include " mlir/Dialect/Tensor/IR/Tensor.h"
15+ #include " mlir/Dialect/Utils/IndexingUtils.h"
1516#include " mlir/Dialect/Utils/ReshapeOpsUtils.h"
1617#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
1718#include " mlir/IR/AffineExpr.h"
@@ -50,28 +51,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
5051 return arith::MulFOp::create (builder, loc, xConvert, yConvert);
5152}
5253
53- // Delinearizes the given composite `index` by the basis specified in `factors`.
54- static SmallVector<Value> unrollIndex (OpBuilder &b, Location loc, Value index ,
55- ArrayRef< int64_t > factors) {
56- assert (!factors. empty () && " empty factor list " );
57- SmallVector<Value> basis;
58- for ( int64_t f : factors)
59- basis. push_back ( arith::ConstantOp::create (b, loc, b. getIndexAttr (f)));
60- FailureOr<SmallVector<Value>> multiIndex =
61- affine::delinearizeIndex (b, loc, index, basis);
62- assert (! failed (multiIndex) && " Failed to linearize img2col index " );
63- return *multiIndex ;
54+ // Generate the affine expression to compute the convolved index
55+ // for the input as `oIndex * stride + fIndex` ,
56+ // where oIndex: output iterator; fIndex: filter iterator.
57+ static AffineExpr getConvolvedExpr (OpBuilder &b, int64_t stride,
58+ bool useSymbols = true ) {
59+ AffineExpr oExpr, fExpr ;
60+ if (useSymbols)
61+ bindSymbols (b. getContext (), oExpr, fExpr );
62+ else
63+ bindDims (b. getContext (), oExpr, fExpr );
64+ return AffineExpr (stride * oExpr + fExpr ) ;
6465}
6566
66- // Given indices corresponding to iterators in the output (oIndex) and filter
67- // (fIndex) for a convolution, compute the convolved index for the
68- // input as `oIndex * stride + fIndex`.
69- static Value getConvolvedIndex (OpBuilder &b, Location loc, Value oIndex,
70- Value fIndex , int64_t stride) {
71- AffineExpr oExpr, fExpr ;
72- bindSymbols (b.getContext (), oExpr, fExpr );
73- AffineMap convMap = AffineMap::get (0 , 2 , stride * oExpr + fExpr );
74- return affine::makeComposedAffineApply (b, loc, convMap, {oIndex, fIndex });
67+ // Stores the affine expressions to map the iteration space of the im2col matrix
68+ // to the corresponding indices of the output and filter matrices
69+ struct Im2ColToOperandsExprs {
70+ AffineExpr fhIndex;
71+ AffineExpr fwIndex;
72+ AffineExpr icIndex;
73+ AffineExpr ohIndex;
74+ AffineExpr owIndex;
75+ };
76+
77+ // Stores the affine expressions to map the iteration space of the im2col matrix
78+ // to the input matrix indices
79+ struct Im2ColToInputDimsExprs {
80+ AffineExpr bIndex;
81+ AffineExpr hIndex;
82+ AffineExpr wIndex;
83+ AffineExpr cIndex;
84+ };
85+
86+ // / Construct the affine expressions that map the indices of the im2col matrix
87+ // / to the corresponding input tensor indices for a 2D convolution with the the
88+ // / provided strides.
89+ // /
90+ // / @param exprs Affine expressions for output and filter indices.
91+ // / @param strides [height, width] stride values for the convolution.
92+ // / @param rewriter Pattern rewriter.
93+ // / @return Affine expressions mapping im2col matrix indices to input
94+ // / offsets.
95+ static Im2ColToInputDimsExprs
96+ getIm2ColInputExpressions (Im2ColToOperandsExprs exprs,
97+ ArrayRef<int64_t > strides, RewriterBase &rewriter) {
98+ // maps the iteration space of the im2col matrix to (output_y, filter_y)
99+ auto hIndicesMap = AffineMap::inferFromExprList (
100+ {ArrayRef{exprs.ohIndex , exprs.fhIndex }}, rewriter.getContext ())[0 ];
101+ // maps the iteration space of the im2col matrix to (output_x, filter_x)
102+ auto wIndicesMap = AffineMap::inferFromExprList (
103+ {ArrayRef{exprs.owIndex , exprs.fwIndex }}, rewriter.getContext ())[0 ];
104+ // Compute the input indexing map, to map the indices of the im2col matrix to
105+ // the original input offsets. Each element of the im2col matrix corresponds
106+ // to a pair of (out_element, filter_element). First, we build the expressions
107+ // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
108+ // then we compose them with the maps that map the im2col matrix elements to
109+ // the (out_element, filter_element) pairs.
110+ auto bIndexExpr = rewriter.getAffineDimExpr (0U );
111+ auto hIndexExpr = getConvolvedExpr (rewriter, strides[0 ],
112+ /* useSymbols*/ false );
113+ hIndexExpr = hIndexExpr.compose (hIndicesMap);
114+ auto wIndexExpr = getConvolvedExpr (rewriter, strides[1 ],
115+ /* useSymbols*/ false );
116+ wIndexExpr = wIndexExpr.compose (wIndicesMap);
117+ auto cIndexExpr = exprs.icIndex ;
118+ return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
75119}
76120
77121FailureOr<std::pair<Operation *, Operation *>>
@@ -135,44 +179,37 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
135179 auto reduction = utils::IteratorType::reduction;
136180 SmallVector<utils::IteratorType> img2colIterators (nloops, parallel);
137181
182+ // Given an index of the im2col matrix, retrieve the corresponding indices of
183+ // the output and filter matrices
184+ auto mIndicesExprs =
185+ delinearize (rewriter.getAffineDimExpr (1U ), ArrayRef<int64_t >{ow, 1 });
186+ auto kIndicesExprs = delinearize (rewriter.getAffineDimExpr (2U ),
187+ ArrayRef<int64_t >{fw * ic, ic, 1 });
188+ Im2ColToOperandsExprs i2cToOperExprs;
189+ i2cToOperExprs.fhIndex = kIndicesExprs [0 ];
190+ i2cToOperExprs.fwIndex = kIndicesExprs [1 ];
191+ i2cToOperExprs.icIndex = kIndicesExprs [2 ];
192+ i2cToOperExprs.ohIndex = mIndicesExprs [0 ];
193+ i2cToOperExprs.owIndex = mIndicesExprs [1 ];
194+
195+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
196+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions (
197+ i2cToOperExprs, llvm::to_vector (convOp.getStrides ().getValues <int64_t >()),
198+ rewriter);
199+ auto inMap =
200+ AffineMap::inferFromExprList ({ArrayRef{inExprs.bIndex , inExprs.hIndex ,
201+ inExprs.wIndex , inExprs.cIndex }},
202+ rewriter.getContext ())[0 ];
203+
138204 SmallVector<AffineMap> img2colIndexingMaps = {
139- AffineMap::getMultiDimIdentityMap (nloops, context)};
205+ inMap, AffineMap::getMultiDimIdentityMap (nloops, context)};
140206
141207 auto img2ColTensor = linalg::GenericOp::create (
142208 rewriter, loc, colTensor.getType (),
143- /* inputs=*/ ValueRange{} , /* outputs=*/ colTensor, img2colIndexingMaps,
209+ /* inputs=*/ input , /* outputs=*/ colTensor, img2colIndexingMaps,
144210 img2colIterators,
145211 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
146- // Get the iterators named based on the matmul (batch, m, k).
147- Value bIndex = linalg::IndexOp::create (nestedBuilder, loc, 0 );
148- Value mIndex = linalg::IndexOp::create (nestedBuilder, loc, 1 );
149- Value kIndex = linalg::IndexOp::create (nestedBuilder, loc, 2 );
150-
151- // Recover the original iteration indices from the problem/input sizes.
152- SmallVector<Value> mIndices = unrollIndex (
153- nestedBuilder, nestedLoc, mIndex , ArrayRef<int64_t >{oh, ow});
154- auto ohIndex = mIndices [0 ];
155- auto owIndex = mIndices [1 ];
156-
157- SmallVector<Value> kIndices = unrollIndex (
158- nestedBuilder, nestedLoc, kIndex , ArrayRef<int64_t >{fh, fw, ic});
159- auto fhIndex = kIndices [0 ];
160- auto fwIndex = kIndices [1 ];
161- auto icIndex = kIndices [2 ];
162-
163- // Extract the input element corresponding to the expanded indices.
164- Value hIndex =
165- getConvolvedIndex (nestedBuilder, nestedLoc, ohIndex, fhIndex,
166- convOp.getStrides ().getValues <int64_t >()[0 ]);
167- Value wIndex =
168- getConvolvedIndex (nestedBuilder, nestedLoc, owIndex, fwIndex,
169- convOp.getStrides ().getValues <int64_t >()[1 ]);
170-
171- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
172- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
173- Value inputVal = tensor::ExtractOp::create (nestedBuilder, loc, input,
174- extractionIndices);
175- linalg::YieldOp::create (nestedBuilder, nestedLoc, inputVal);
212+ linalg::YieldOp::create (nestedBuilder, nestedLoc, args[0 ]);
176213 });
177214
178215 // Because the filter does not share the same batch dimension,
@@ -421,44 +458,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
421458 auto reduction = utils::IteratorType::reduction;
422459 SmallVector<utils::IteratorType, 3 > img2colIterators (nloops, parallel);
423460
424- SmallVector<AffineMap, 4 > img2colIndexingMaps = {
425- AffineMap::getMultiDimIdentityMap (nloops, context)};
461+ // Recover the original iteration indices from the problem/input sizes:
462+ // given an index of the im2col matrix, retrieve the corresponding indices of
463+ // the output and filter matrices
464+ auto kIndicesExprs = delinearize (rewriter.getAffineDimExpr (1U ),
465+ ArrayRef<int64_t >{fh * fw, fw, 1 });
466+ auto mIndicesExprs =
467+ delinearize (rewriter.getAffineDimExpr (2U ), ArrayRef<int64_t >{ow, 1 });
468+ Im2ColToOperandsExprs i2cToOperExprs;
469+ i2cToOperExprs.icIndex = kIndicesExprs [0 ];
470+ i2cToOperExprs.fhIndex = kIndicesExprs [1 ];
471+ i2cToOperExprs.fwIndex = kIndicesExprs [2 ];
472+ i2cToOperExprs.ohIndex = mIndicesExprs [0 ];
473+ i2cToOperExprs.owIndex = mIndicesExprs [1 ];
474+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions (
475+ i2cToOperExprs, llvm::to_vector (convOp.getStrides ().getValues <int64_t >()),
476+ rewriter);
477+ auto inMap =
478+ AffineMap::inferFromExprList ({ArrayRef{inExprs.bIndex , inExprs.cIndex ,
479+ inExprs.hIndex , inExprs.wIndex }},
480+ rewriter.getContext ())[0 ];
481+ // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
482+ SmallVector<AffineMap> img2colIndexingMaps = {
483+ inMap, AffineMap::getMultiDimIdentityMap (nloops, context)};
426484
427485 auto img2ColTensor = linalg::GenericOp::create (
428486 rewriter, loc, colTensor.getType (),
429- /* inputs=*/ ValueRange{} , /* outputs=*/ colTensor, img2colIndexingMaps,
487+ /* inputs=*/ input , /* outputs=*/ colTensor, img2colIndexingMaps,
430488 img2colIterators,
431489 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
432- // Get the iterators named based on the matmul (batch, m, k).
433- Value bIndex = linalg::IndexOp::create (nestedBuilder, loc, 0 );
434- Value kIndex = linalg::IndexOp::create (nestedBuilder, loc, 1 );
435- Value nIndex = linalg::IndexOp::create (nestedBuilder, loc, 2 );
436-
437- // Recover the original iteration indices from the problem/input sizes.
438- SmallVector<Value> kIndices = unrollIndex (
439- nestedBuilder, nestedLoc, kIndex , ArrayRef<int64_t >{ic, fh, fw});
440- auto icIndex = kIndices [0 ];
441- auto fhIndex = kIndices [1 ];
442- auto fwIndex = kIndices [2 ];
443-
444- SmallVector<Value> nIndices = unrollIndex (
445- nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t >{oh, ow});
446- auto ohIndex = nIndices[0 ];
447- auto owIndex = nIndices[1 ];
448-
449- // Extract the input element corresponding to the expanded indices.
450- Value hIndex =
451- getConvolvedIndex (nestedBuilder, nestedLoc, ohIndex, fhIndex,
452- convOp.getStrides ().getValues <int64_t >()[0 ]);
453- Value wIndex =
454- getConvolvedIndex (nestedBuilder, nestedLoc, owIndex, fwIndex,
455- convOp.getStrides ().getValues <int64_t >()[1 ]);
456-
457- // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
458- SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
459- Value inputVal = tensor::ExtractOp::create (nestedBuilder, loc, input,
460- extractionIndices);
461- linalg::YieldOp::create (nestedBuilder, nestedLoc, inputVal);
490+ linalg::YieldOp::create (nestedBuilder, nestedLoc, args[0 ]);
462491 });
463492
464493 // Because the filter does not share the same batch dimension,
@@ -545,6 +574,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
545574 Value reshapedOutput = tensor::CollapseShapeOp::create (
546575 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
547576
577+ // Shape of the Toeplitz matrix produced by Im2col.
548578 SmallVector<int64_t > colTensorShape = {n, oh * ow, fh * fw * ic};
549579 Value colTensor = tensor::EmptyOp::create (rewriter, loc, colTensorShape,
550580 inputType.getElementType ());
@@ -556,44 +586,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
556586 auto reduction = utils::IteratorType::reduction;
557587 SmallVector<utils::IteratorType> img2colIterators (nloops, parallel);
558588
589+ // Given an index of the im2col matrix, retrieve the corresponding indices of
590+ // the output and filter matrices
591+ auto mIndicesExprs =
592+ delinearize (rewriter.getAffineDimExpr (1U ), ArrayRef<int64_t >{ow, 1 });
593+ auto kIndicesExprs = delinearize (rewriter.getAffineDimExpr (2U ),
594+ ArrayRef<int64_t >{fw * ic, ic, 1 });
595+ Im2ColToOperandsExprs i2cToOperExprs;
596+ i2cToOperExprs.fhIndex = kIndicesExprs [0 ];
597+ i2cToOperExprs.fwIndex = kIndicesExprs [1 ];
598+ i2cToOperExprs.icIndex = kIndicesExprs [2 ];
599+ i2cToOperExprs.ohIndex = mIndicesExprs [0 ];
600+ i2cToOperExprs.owIndex = mIndicesExprs [1 ];
601+
602+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
603+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions (
604+ i2cToOperExprs, llvm::to_vector (convOp.getStrides ().getValues <int64_t >()),
605+ rewriter);
606+ auto inMap =
607+ AffineMap::inferFromExprList ({ArrayRef{inExprs.bIndex , inExprs.hIndex ,
608+ inExprs.wIndex , inExprs.cIndex }},
609+ rewriter.getContext ())[0 ];
559610 SmallVector<AffineMap> img2colIndexingMaps = {
560- AffineMap::getMultiDimIdentityMap (nloops, context)};
611+ inMap, AffineMap::getMultiDimIdentityMap (nloops, context)};
561612
562613 auto img2ColTensor = linalg::GenericOp::create (
563614 rewriter, loc, colTensor.getType (),
564- /* inputs=*/ ValueRange{} , /* outputs=*/ colTensor, img2colIndexingMaps,
615+ /* inputs=*/ input , /* outputs=*/ colTensor, img2colIndexingMaps,
565616 img2colIterators,
566617 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
567- // Get the iterators named based on the matmul (batch, m, k).
568- Value bIndex = linalg::IndexOp::create (nestedBuilder, loc, 0 );
569- Value mIndex = linalg::IndexOp::create (nestedBuilder, loc, 1 );
570- Value kIndex = linalg::IndexOp::create (nestedBuilder, loc, 2 );
571-
572- // Recover the original iteration indices from the problem/input sizes.
573- SmallVector<Value> mIndices = unrollIndex (
574- nestedBuilder, nestedLoc, mIndex , ArrayRef<int64_t >{oh, ow});
575- auto ohIndex = mIndices [0 ];
576- auto owIndex = mIndices [1 ];
577-
578- SmallVector<Value> kIndices = unrollIndex (
579- nestedBuilder, nestedLoc, kIndex , ArrayRef<int64_t >{fh, fw, ic});
580- auto fhIndex = kIndices [0 ];
581- auto fwIndex = kIndices [1 ];
582- auto icIndex = kIndices [2 ];
583-
584- // Extract the input element corresponding to the expanded indices.
585- Value hIndex =
586- getConvolvedIndex (nestedBuilder, nestedLoc, ohIndex, fhIndex,
587- convOp.getStrides ().getValues <int64_t >()[0 ]);
588- Value wIndex =
589- getConvolvedIndex (nestedBuilder, nestedLoc, owIndex, fwIndex,
590- convOp.getStrides ().getValues <int64_t >()[1 ]);
591-
592- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
593- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
594- Value inputVal = tensor::ExtractOp::create (nestedBuilder, loc, input,
595- extractionIndices);
596- linalg::YieldOp::create (nestedBuilder, nestedLoc, inputVal);
618+ linalg::YieldOp::create (nestedBuilder, nestedLoc, args[0 ]);
597619 });
598620
599621 // Because we didn't transpose the filters we don't actually have a batched
0 commit comments