Commit 5d51e00
[mlir][linalg] Propagate filter tensor encoding in im2col (#160099)
In the im2col decomposition, propagate the filter tensor encoding (if specified) through the tensor.collapse_shape op, so that it can be used by the consuming linalg.generic matmul op.

Signed-off-by: Fabrizio Indirli <[email protected]>
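To make the effect concrete, here is a minimal MLIR sketch (the shapes and the placeholder encoding 42 : i32 are taken from the new tests below):

// Before this change, the collapsed filter type dropped the encoding,
// yielding a plain tensor<16x36xf32>; it is now carried through:
%collapsed = tensor.collapse_shape %filter [[0], [1, 2, 3]]
    : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>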
1 parent 745e1e6 commit 5d51e00

3 files changed, +107 -4 lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 0 deletions
@@ -1858,6 +1858,7 @@ void populateDecomposePadPatterns(RewritePatternSet &patterns);
 
 /// Populates patterns to transform linalg.conv_2d_xxx operations into
 /// linalg.generic (for img2col packing) and linalg.matmul.
+/// Note: currently limited to Tensor semantics only.
 /// \see rewriteInIm2Col for more details.
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
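For context, a minimal sketch of how these patterns are usually driven; the helper name is hypothetical, and the greedy-driver call is the common MLIR idiom rather than anything added by this patch:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: collect the img2col patterns and apply them
// greedily under `op`. Per the note above, convs with buffer (memref)
// semantics now simply fail to match instead of being miscompiled.
static LogicalResult convertConvsToImg2Col(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  linalg::populateConvertConv2DToImg2ColPatterns(patterns);
  return applyPatternsGreedily(op, std::move(patterns));
}

(Older trees spell the driver applyPatternsAndFoldGreedily.)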

mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp

Lines changed: 35 additions & 3 deletions
@@ -20,6 +20,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include <cassert>
 #include <utility>
 
 namespace mlir {
@@ -124,6 +125,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -155,10 +160,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  assert(isa<RankedTensorType>(filterType) &&
+         "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a (B)MNK matmul.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
   auto reshapedFilterType =
-      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
+      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -253,6 +263,10 @@ rewriteInIm2Col(RewriterBase &rewriter,
   auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
   auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -404,6 +418,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -435,9 +453,14 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
   auto loc = convOp.getLoc();
   MLIRContext *context = rewriter.getContext();
 
+  assert(isa<RankedTensorType>(filterType) &&
+         "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
+      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
@@ -529,6 +552,10 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
+  if (!convOp.hasPureTensorSemantics())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected op to have pure tensor semantics");
+
   if (!filterType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         convOp, "expected a static shape for the filter");
@@ -560,11 +587,16 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
 
   Location loc = convOp.getLoc();
 
+  assert(isa<RankedTensorType>(filterType) &&
+         "expected filter type to be a ranked tensor");
+  auto tensorFilterType = cast<RankedTensorType>(filterType);
+
   // Reshape output and filter to the LHS and result of a "row-wise" matrix
   // multiplication.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
   auto reshapedFilterType =
-      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
+                            tensorFilterType.getEncoding());
   Value reshapedFilter = tensor::CollapseShapeOp::create(
       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
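The API point underlying these changes is the three-argument form of RankedTensorType::get, whose optional third argument is the encoding attribute. A small self-contained sketch, with the 42 : i32 encoding mirroring the tests:

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  Builder b(&ctx);
  // A 16x4x3x3 f32 filter type carrying an opaque encoding attribute.
  Attribute enc = b.getI32IntegerAttr(42);
  auto filterTy = RankedTensorType::get({16, 4, 3, 3}, b.getF32Type(), enc);
  // The collapsed 2-D type reuses the element type *and* the encoding,
  // which is exactly what the rewrites above now do for the filter.
  auto collapsedTy = RankedTensorType::get(
      {16, 4 * 3 * 3}, filterTy.getElementType(), filterTy.getEncoding());
  (void)collapsedTy; // prints as tensor<16x36xf32, 42 : i32> when dumped
  return 0;
}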

mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir

Lines changed: 71 additions & 1 deletion
@@ -26,6 +26,26 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Memref semantics is not supported.
+// Check that we emit an error.
+func.func @negative_conv_memref(%arg0: memref<1x16x16x4xf32>, %arg1: memref<16x3x3x4xf32>, %arg2: memref<1x14x14x16xf32>) {
+  // expected-note@below {{when applied to this op}}
+  linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : memref<2xi64>, strides = dense<1> : memref<2xi64> }
+    ins(%arg0, %arg1: memref<1x16x16x4xf32>, memref<16x3x3x4xf32>) outs(%arg2: memref<1x14x14x16xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{failed to apply}}
+    %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check that we get the proper handles for the img2col tensor producer
 // and the final instruction.
@@ -267,6 +287,31 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that the encoding on the filter (weights) tensor is propagated when applying the transform.
+
+// CHECK: func.func @batch_nchw_conv_with_filter_encoding(%[[INPUT:.+]]: tensor<8x4x16x16xf32>, %[[FILTER:.*]]: tensor<16x4x3x3xf32, 42 : i32>, %[[OUTPUT:.*]]: tensor<8x16x14x14xf32>)
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x4x3x3xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {{.*}} ins(%[[COLLAPSED_FILTER]], %[[COL_TENSOR]] : tensor<16x36xf32, 42 : i32>, tensor<8x36x196xf32>)
+func.func @batch_nchw_conv_with_filter_encoding(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32, 42 : i32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
+  %0 = linalg.conv_2d_nchw_fchw
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32, 42 : i32>)
+    outs(%arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+  return %0 : tensor<8x16x14x14xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: IR printer: tensor_producer
 // CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
 // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
@@ -290,7 +335,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
 // CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
 // CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]], {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>) outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
 // CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
 // CHECK: linalg.yield %{{.+}} : f32
 // CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
@@ -327,6 +372,31 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that the encoding on the filter (weights) tensor is propagated when applying the transform.
+
+// CHECK: func.func @conv_2d_nhwc_fhwc_with_filter_encoding(%[[INPUT:.+]]: tensor<1x16x16x4xf32>, %[[FILTER:.*]]: tensor<16x3x3x4xf32, 42 : i32>, %[[OUTPUT:.*]]: tensor<1x14x14x16xf32>)
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] : tensor<16x3x3x4xf32, 42 : i32> into tensor<16x36xf32, 42 : i32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic {{.*}} ins(%[[INPUT]] : tensor<1x16x16x4xf32>)
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {{.*}} ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32, 42 : i32>)
+func.func @conv_2d_nhwc_fhwc_with_filter_encoding(%input: tensor<1x16x16x4xf32>, %filter: tensor<16x3x3x4xf32, 42 : i32>, %out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc
+    { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%input, %filter: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32, 42 : i32>)
+    outs(%out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check for signed extend when the input type is smaller than the accumulator type.
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
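As a usage note, tests in this file are exercised through the transform interpreter; the exact RUN line is not part of this diff, so the following is an assumption about the typical invocation:

// RUN: mlir-opt %s -transform-interpreter -split-input-file \
// RUN:   -verify-diagnostics | FileCheck %s

The -verify-diagnostics mode is what consumes the expected-error and expected-note annotations in the new negative memref test.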
