12
12
#include " mlir/Dialect/Linalg/IR/Linalg.h"
13
13
#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
14
14
#include " mlir/Dialect/Tensor/IR/Tensor.h"
15
+ #include " mlir/Dialect/Utils/IndexingUtils.h"
15
16
#include " mlir/Dialect/Utils/ReshapeOpsUtils.h"
16
17
#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
17
18
#include " mlir/IR/AffineExpr.h"
@@ -50,28 +51,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
50
51
return arith::MulFOp::create (builder, loc, xConvert, yConvert);
51
52
}
52
53
53
- // Delinearizes the given composite `index` by the basis specified in `factors`.
54
- static SmallVector<Value> unrollIndex (OpBuilder &b, Location loc, Value index ,
55
- ArrayRef< int64_t > factors) {
56
- assert (!factors. empty () && " empty factor list " );
57
- SmallVector<Value> basis;
58
- for ( int64_t f : factors)
59
- basis. push_back ( arith::ConstantOp::create (b, loc, b. getIndexAttr (f)));
60
- FailureOr<SmallVector<Value>> multiIndex =
61
- affine::delinearizeIndex (b, loc, index, basis);
62
- assert (! failed (multiIndex) && " Failed to linearize img2col index " );
63
- return *multiIndex ;
54
+ // Generate the affine expression to compute the convolved index
55
+ // for the input as `oIndex * stride + fIndex` ,
56
+ // where oIndex: output iterator; fIndex: filter iterator.
57
+ static AffineExpr getConvolvedExpr (OpBuilder &b, int64_t stride,
58
+ bool useSymbols = true ) {
59
+ AffineExpr oExpr, fExpr ;
60
+ if (useSymbols)
61
+ bindSymbols (b. getContext (), oExpr, fExpr );
62
+ else
63
+ bindDims (b. getContext (), oExpr, fExpr );
64
+ return AffineExpr (stride * oExpr + fExpr ) ;
64
65
}
65
66
66
- // Given indices corresponding to iterators in the output (oIndex) and filter
67
- // (fIndex) for a convolution, compute the convolved index for the
68
- // input as `oIndex * stride + fIndex`.
69
- static Value getConvolvedIndex (OpBuilder &b, Location loc, Value oIndex,
70
- Value fIndex , int64_t stride) {
71
- AffineExpr oExpr, fExpr ;
72
- bindSymbols (b.getContext (), oExpr, fExpr );
73
- AffineMap convMap = AffineMap::get (0 , 2 , stride * oExpr + fExpr );
74
- return affine::makeComposedAffineApply (b, loc, convMap, {oIndex, fIndex });
67
+ // Stores the affine expressions to map the iteration space of the im2col matrix
68
+ // to the corresponding indices of the output and filter matrices
69
+ struct Im2ColToOperandsExprs {
70
+ AffineExpr fhIndex;
71
+ AffineExpr fwIndex;
72
+ AffineExpr icIndex;
73
+ AffineExpr ohIndex;
74
+ AffineExpr owIndex;
75
+ };
76
+
77
+ // Stores the affine expressions to map the iteration space of the im2col matrix
78
+ // to the input matrix indices
79
+ struct Im2ColToInputDimsExprs {
80
+ AffineExpr bIndex;
81
+ AffineExpr hIndex;
82
+ AffineExpr wIndex;
83
+ AffineExpr cIndex;
84
+ };
85
+
86
+ // / Construct the affine expressions that map the indices of the im2col matrix
87
+ // / to the corresponding input tensor indices for a 2D convolution with the the
88
+ // / provided strides.
89
+ // /
90
+ // / @param exprs Affine expressions for output and filter indices.
91
+ // / @param strides [height, width] stride values for the convolution.
92
+ // / @param rewriter Pattern rewriter.
93
+ // / @return Affine expressions mapping im2col matrix indices to input
94
+ // / offsets.
95
+ static Im2ColToInputDimsExprs
96
+ getIm2ColInputExpressions (Im2ColToOperandsExprs exprs,
97
+ ArrayRef<int64_t > strides, RewriterBase &rewriter) {
98
+ // maps the iteration space of the im2col matrix to (output_y, filter_y)
99
+ auto hIndicesMap = AffineMap::inferFromExprList (
100
+ {ArrayRef{exprs.ohIndex , exprs.fhIndex }}, rewriter.getContext ())[0 ];
101
+ // maps the iteration space of the im2col matrix to (output_x, filter_x)
102
+ auto wIndicesMap = AffineMap::inferFromExprList (
103
+ {ArrayRef{exprs.owIndex , exprs.fwIndex }}, rewriter.getContext ())[0 ];
104
+ // Compute the input indexing map, to map the indices of the im2col matrix to
105
+ // the original input offsets. Each element of the im2col matrix corresponds
106
+ // to a pair of (out_element, filter_element). First, we build the expressions
107
+ // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
108
+ // then we compose them with the maps that map the im2col matrix elements to
109
+ // the (out_element, filter_element) pairs.
110
+ auto bIndexExpr = rewriter.getAffineDimExpr (0U );
111
+ auto hIndexExpr = getConvolvedExpr (rewriter, strides[0 ],
112
+ /* useSymbols*/ false );
113
+ hIndexExpr = hIndexExpr.compose (hIndicesMap);
114
+ auto wIndexExpr = getConvolvedExpr (rewriter, strides[1 ],
115
+ /* useSymbols*/ false );
116
+ wIndexExpr = wIndexExpr.compose (wIndicesMap);
117
+ auto cIndexExpr = exprs.icIndex ;
118
+ return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
75
119
}
76
120
77
121
FailureOr<std::pair<Operation *, Operation *>>
@@ -135,44 +179,37 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
135
179
auto reduction = utils::IteratorType::reduction;
136
180
SmallVector<utils::IteratorType> img2colIterators (nloops, parallel);
137
181
182
+ // Given an index of the im2col matrix, retrieve the corresponding indices of
183
+ // the output and filter matrices
184
+ auto mIndicesExprs =
185
+ delinearize (rewriter.getAffineDimExpr (1U ), ArrayRef<int64_t >{ow, 1 });
186
+ auto kIndicesExprs = delinearize (rewriter.getAffineDimExpr (2U ),
187
+ ArrayRef<int64_t >{fw * ic, ic, 1 });
188
+ Im2ColToOperandsExprs i2cToOperExprs;
189
+ i2cToOperExprs.fhIndex = kIndicesExprs [0 ];
190
+ i2cToOperExprs.fwIndex = kIndicesExprs [1 ];
191
+ i2cToOperExprs.icIndex = kIndicesExprs [2 ];
192
+ i2cToOperExprs.ohIndex = mIndicesExprs [0 ];
193
+ i2cToOperExprs.owIndex = mIndicesExprs [1 ];
194
+
195
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
196
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions (
197
+ i2cToOperExprs, llvm::to_vector (convOp.getStrides ().getValues <int64_t >()),
198
+ rewriter);
199
+ auto inMap =
200
+ AffineMap::inferFromExprList ({ArrayRef{inExprs.bIndex , inExprs.hIndex ,
201
+ inExprs.wIndex , inExprs.cIndex }},
202
+ rewriter.getContext ())[0 ];
203
+
138
204
SmallVector<AffineMap> img2colIndexingMaps = {
139
- AffineMap::getMultiDimIdentityMap (nloops, context)};
205
+ inMap, AffineMap::getMultiDimIdentityMap (nloops, context)};
140
206
141
207
auto img2ColTensor = linalg::GenericOp::create (
142
208
rewriter, loc, colTensor.getType (),
143
- /* inputs=*/ ValueRange{} , /* outputs=*/ colTensor, img2colIndexingMaps,
209
+ /* inputs=*/ input , /* outputs=*/ colTensor, img2colIndexingMaps,
144
210
img2colIterators,
145
211
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
146
- // Get the iterators named based on the matmul (batch, m, k).
147
- Value bIndex = linalg::IndexOp::create (nestedBuilder, loc, 0 );
148
- Value mIndex = linalg::IndexOp::create (nestedBuilder, loc, 1 );
149
- Value kIndex = linalg::IndexOp::create (nestedBuilder, loc, 2 );
150
-
151
- // Recover the original iteration indices from the problem/input sizes.
152
- SmallVector<Value> mIndices = unrollIndex (
153
- nestedBuilder, nestedLoc, mIndex , ArrayRef<int64_t >{oh, ow});
154
- auto ohIndex = mIndices [0 ];
155
- auto owIndex = mIndices [1 ];
156
-
157
- SmallVector<Value> kIndices = unrollIndex (
158
- nestedBuilder, nestedLoc, kIndex , ArrayRef<int64_t >{fh, fw, ic});
159
- auto fhIndex = kIndices [0 ];
160
- auto fwIndex = kIndices [1 ];
161
- auto icIndex = kIndices [2 ];
162
-
163
- // Extract the input element corresponding to the expanded indices.
164
- Value hIndex =
165
- getConvolvedIndex (nestedBuilder, nestedLoc, ohIndex, fhIndex,
166
- convOp.getStrides ().getValues <int64_t >()[0 ]);
167
- Value wIndex =
168
- getConvolvedIndex (nestedBuilder, nestedLoc, owIndex, fwIndex,
169
- convOp.getStrides ().getValues <int64_t >()[1 ]);
170
-
171
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
172
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
173
- Value inputVal = tensor::ExtractOp::create (nestedBuilder, loc, input,
174
- extractionIndices);
175
- linalg::YieldOp::create (nestedBuilder, nestedLoc, inputVal);
212
+ linalg::YieldOp::create (nestedBuilder, nestedLoc, args[0 ]);
176
213
});
177
214
178
215
// Because the filter does not share the same batch dimension,
@@ -421,44 +458,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
421
458
auto reduction = utils::IteratorType::reduction;
422
459
SmallVector<utils::IteratorType, 3 > img2colIterators (nloops, parallel);
423
460
424
- SmallVector<AffineMap, 4 > img2colIndexingMaps = {
425
- AffineMap::getMultiDimIdentityMap (nloops, context)};
461
+ // Recover the original iteration indices from the problem/input sizes:
462
+ // given an index of the im2col matrix, retrieve the corresponding indices of
463
+ // the output and filter matrices
464
+ auto kIndicesExprs = delinearize (rewriter.getAffineDimExpr (1U ),
465
+ ArrayRef<int64_t >{fh * fw, fw, 1 });
466
+ auto mIndicesExprs =
467
+ delinearize (rewriter.getAffineDimExpr (2U ), ArrayRef<int64_t >{ow, 1 });
468
+ Im2ColToOperandsExprs i2cToOperExprs;
469
+ i2cToOperExprs.icIndex = kIndicesExprs [0 ];
470
+ i2cToOperExprs.fhIndex = kIndicesExprs [1 ];
471
+ i2cToOperExprs.fwIndex = kIndicesExprs [2 ];
472
+ i2cToOperExprs.ohIndex = mIndicesExprs [0 ];
473
+ i2cToOperExprs.owIndex = mIndicesExprs [1 ];
474
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions (
475
+ i2cToOperExprs, llvm::to_vector (convOp.getStrides ().getValues <int64_t >()),
476
+ rewriter);
477
+ auto inMap =
478
+ AffineMap::inferFromExprList ({ArrayRef{inExprs.bIndex , inExprs.cIndex ,
479
+ inExprs.hIndex , inExprs.wIndex }},
480
+ rewriter.getContext ())[0 ];
481
+ // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
482
+ SmallVector<AffineMap> img2colIndexingMaps = {
483
+ inMap, AffineMap::getMultiDimIdentityMap (nloops, context)};
426
484
427
485
auto img2ColTensor = linalg::GenericOp::create (
428
486
rewriter, loc, colTensor.getType (),
429
- /* inputs=*/ ValueRange{} , /* outputs=*/ colTensor, img2colIndexingMaps,
487
+ /* inputs=*/ input , /* outputs=*/ colTensor, img2colIndexingMaps,
430
488
img2colIterators,
431
489
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
432
- // Get the iterators named based on the matmul (batch, m, k).
433
- Value bIndex = linalg::IndexOp::create (nestedBuilder, loc, 0 );
434
- Value kIndex = linalg::IndexOp::create (nestedBuilder, loc, 1 );
435
- Value nIndex = linalg::IndexOp::create (nestedBuilder, loc, 2 );
436
-
437
- // Recover the original iteration indices from the problem/input sizes.
438
- SmallVector<Value> kIndices = unrollIndex (
439
- nestedBuilder, nestedLoc, kIndex , ArrayRef<int64_t >{ic, fh, fw});
440
- auto icIndex = kIndices [0 ];
441
- auto fhIndex = kIndices [1 ];
442
- auto fwIndex = kIndices [2 ];
443
-
444
- SmallVector<Value> nIndices = unrollIndex (
445
- nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t >{oh, ow});
446
- auto ohIndex = nIndices[0 ];
447
- auto owIndex = nIndices[1 ];
448
-
449
- // Extract the input element corresponding to the expanded indices.
450
- Value hIndex =
451
- getConvolvedIndex (nestedBuilder, nestedLoc, ohIndex, fhIndex,
452
- convOp.getStrides ().getValues <int64_t >()[0 ]);
453
- Value wIndex =
454
- getConvolvedIndex (nestedBuilder, nestedLoc, owIndex, fwIndex,
455
- convOp.getStrides ().getValues <int64_t >()[1 ]);
456
-
457
- // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
458
- SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
459
- Value inputVal = tensor::ExtractOp::create (nestedBuilder, loc, input,
460
- extractionIndices);
461
- linalg::YieldOp::create (nestedBuilder, nestedLoc, inputVal);
490
+ linalg::YieldOp::create (nestedBuilder, nestedLoc, args[0 ]);
462
491
});
463
492
464
493
// Because the filter does not share the same batch dimension,
@@ -545,6 +574,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
545
574
Value reshapedOutput = tensor::CollapseShapeOp::create (
546
575
rewriter, loc, reshapedOutputType, output, outputReassocIndices);
547
576
577
+ // Shape of the Toeplitz matrix produced by Im2col.
548
578
SmallVector<int64_t > colTensorShape = {n, oh * ow, fh * fw * ic};
549
579
Value colTensor = tensor::EmptyOp::create (rewriter, loc, colTensorShape,
550
580
inputType.getElementType ());
@@ -556,44 +586,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
556
586
auto reduction = utils::IteratorType::reduction;
557
587
SmallVector<utils::IteratorType> img2colIterators (nloops, parallel);
558
588
589
+ // Given an index of the im2col matrix, retrieve the corresponding indices of
590
+ // the output and filter matrices
591
+ auto mIndicesExprs =
592
+ delinearize (rewriter.getAffineDimExpr (1U ), ArrayRef<int64_t >{ow, 1 });
593
+ auto kIndicesExprs = delinearize (rewriter.getAffineDimExpr (2U ),
594
+ ArrayRef<int64_t >{fw * ic, ic, 1 });
595
+ Im2ColToOperandsExprs i2cToOperExprs;
596
+ i2cToOperExprs.fhIndex = kIndicesExprs [0 ];
597
+ i2cToOperExprs.fwIndex = kIndicesExprs [1 ];
598
+ i2cToOperExprs.icIndex = kIndicesExprs [2 ];
599
+ i2cToOperExprs.ohIndex = mIndicesExprs [0 ];
600
+ i2cToOperExprs.owIndex = mIndicesExprs [1 ];
601
+
602
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
603
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions (
604
+ i2cToOperExprs, llvm::to_vector (convOp.getStrides ().getValues <int64_t >()),
605
+ rewriter);
606
+ auto inMap =
607
+ AffineMap::inferFromExprList ({ArrayRef{inExprs.bIndex , inExprs.hIndex ,
608
+ inExprs.wIndex , inExprs.cIndex }},
609
+ rewriter.getContext ())[0 ];
559
610
SmallVector<AffineMap> img2colIndexingMaps = {
560
- AffineMap::getMultiDimIdentityMap (nloops, context)};
611
+ inMap, AffineMap::getMultiDimIdentityMap (nloops, context)};
561
612
562
613
auto img2ColTensor = linalg::GenericOp::create (
563
614
rewriter, loc, colTensor.getType (),
564
- /* inputs=*/ ValueRange{} , /* outputs=*/ colTensor, img2colIndexingMaps,
615
+ /* inputs=*/ input , /* outputs=*/ colTensor, img2colIndexingMaps,
565
616
img2colIterators,
566
617
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
567
- // Get the iterators named based on the matmul (batch, m, k).
568
- Value bIndex = linalg::IndexOp::create (nestedBuilder, loc, 0 );
569
- Value mIndex = linalg::IndexOp::create (nestedBuilder, loc, 1 );
570
- Value kIndex = linalg::IndexOp::create (nestedBuilder, loc, 2 );
571
-
572
- // Recover the original iteration indices from the problem/input sizes.
573
- SmallVector<Value> mIndices = unrollIndex (
574
- nestedBuilder, nestedLoc, mIndex , ArrayRef<int64_t >{oh, ow});
575
- auto ohIndex = mIndices [0 ];
576
- auto owIndex = mIndices [1 ];
577
-
578
- SmallVector<Value> kIndices = unrollIndex (
579
- nestedBuilder, nestedLoc, kIndex , ArrayRef<int64_t >{fh, fw, ic});
580
- auto fhIndex = kIndices [0 ];
581
- auto fwIndex = kIndices [1 ];
582
- auto icIndex = kIndices [2 ];
583
-
584
- // Extract the input element corresponding to the expanded indices.
585
- Value hIndex =
586
- getConvolvedIndex (nestedBuilder, nestedLoc, ohIndex, fhIndex,
587
- convOp.getStrides ().getValues <int64_t >()[0 ]);
588
- Value wIndex =
589
- getConvolvedIndex (nestedBuilder, nestedLoc, owIndex, fwIndex,
590
- convOp.getStrides ().getValues <int64_t >()[1 ]);
591
-
592
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
593
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
594
- Value inputVal = tensor::ExtractOp::create (nestedBuilder, loc, input,
595
- extractionIndices);
596
- linalg::YieldOp::create (nestedBuilder, nestedLoc, inputVal);
618
+ linalg::YieldOp::create (nestedBuilder, nestedLoc, args[0 ]);
597
619
});
598
620
599
621
// Because we didn't transpose the filters we don't actually have a batched
0 commit comments