Commit 7087972

[GlobalOpt] Generalize 1x1 group convolutions (#20480)
This allows 1x1 group convolutions to be generalized, since they are effectively loops around matmuls. On LLVMGPU, this lets them go down the contraction/matmul lowering path.
Parent: 5b55234
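For context, here is a rough sketch of the generalized form for the new test case added below (hypothetical function and map names; the literal pass output may differ in attribute details). With a 1x1 window and unit strides, the kh/kw window dims are unit-sized reductions, so the only meaningful reduction is over the input channel c, i.e. a grouped, spatially-batched matmul:

// Hand-written sketch, not the literal pass output.
#input_map  = affine_map<(n, oh, ow, g, f, kh, kw, c) -> (n, oh + kh, ow + kw, g, c)>
#filter_map = affine_map<(n, oh, ow, g, f, kh, kw, c) -> (g, f, kh, kw, c)>
#output_map = affine_map<(n, oh, ow, g, f, kh, kw, c) -> (n, oh, ow, g, f)>

util.func public @generalized_1x1_group_conv_sketch(%input: tensor<1x2x3x4x5xf32>,
    %filter: tensor<4x6x1x1x5xf32>) -> tensor<1x2x3x4x6xf32> {
  %init = tensor.empty() : tensor<1x2x3x4x6xf32>
  %result = linalg.generic {
      indexing_maps = [#input_map, #filter_map, #output_map],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel",
                        "reduction", "reduction", "reduction"]}
      ins(%input, %filter : tensor<1x2x3x4x5xf32>, tensor<4x6x1x1x5xf32>)
      outs(%init : tensor<1x2x3x4x6xf32>) {
  ^bb0(%in: f32, %flt: f32, %acc: f32):
    // Multiply-accumulate over the (unit kh/kw and) input-channel reduction dims.
    %mul = arith.mulf %in, %flt : f32
    %acc_next = arith.addf %acc, %mul : f32
    linalg.yield %acc_next : f32
  } -> tensor<1x2x3x4x6xf32>
  util.return %result : tensor<1x2x3x4x6xf32>
}

Because the non-unit loop structure matches a contraction, backends such as LLVMGPU can route it to their matmul lowering.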

File tree: 2 files changed (+41, -7 lines)
compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp

Lines changed: 11 additions & 7 deletions

@@ -17,6 +17,10 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
 
+#define DEBUG_TYPE "iree-global-opt-generalize-linalg-named-ops"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace mlir::iree_compiler::GlobalOptimization {
 
 #define GEN_PASS_DEF_GENERALIZELINALGNAMEDOPSPASS

@@ -41,17 +45,15 @@ static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) {
 
   if (!llvm::all_of(convDims.strides,
                     [](int64_t element) { return element == 1; })) {
+    LDBG("conv not foldable: non-unit strides");
     return false;
   }
 
-  // Dont generalize depthwise convolutions.
-  if (!convDims.depth.empty()) {
-    return false;
-  }
-
-  // Dont generalize pooling operations. For pooling ops, the input/output
-  // channel size will be categorized as the additional batch dimension
+  // Dont generalize pooling operations or depthwise convolutions. For pooling
+  // ops, the input/output channel size will be categorized as the additional
+  // batch dimension.
   if (convDims.outputChannel.empty() || convDims.inputChannel.empty()) {
+    LDBG("conv not foldable: missing input or output channel dims");
     return false;
   }
 

@@ -60,6 +62,7 @@ static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) {
   auto filterShapeType = llvm::dyn_cast<RankedTensorType>(
       linalgOp.getDpsInputOperand(kFilterInputIdx)->get().getType());
   if (!filterShapeType) {
+    LDBG("conv not foldable: filter shape not ranked tensor");
     return false;
   }
   auto filterShape = filterShapeType.getShape();

@@ -68,6 +71,7 @@ static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) {
     std::optional<int64_t> maybeDim = filterMap.getResultPosition(
         getAffineDimExpr(filterLoop, filterMap.getContext()));
     if (!maybeDim || filterShape[*maybeDim] != 1) {
+      LDBG("conv not foldable: non-unit filter dim");
      return false;
     }
   }
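Note: the LDBG messages added above go through LLVM's LLVM_DEBUG machinery, so they are compiled in only for asserts-enabled builds; they should then be printable by passing --debug-only=iree-global-opt-generalize-linalg-named-ops (the DEBUG_TYPE defined in this file) to iree-opt.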

compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir

Lines changed: 30 additions & 0 deletions

@@ -101,6 +101,21 @@ util.func public @generalize_1x1_conv_2d_dilations(%input: tensor<1x4x?x2xf32>,
 
 // -----
 
+util.func public @generalize_1x1_group_conv_2d(%input: tensor<1x2x3x4x5xf32>, %filter: tensor<4x6x1x1x5xf32>) -> tensor<1x2x3x4x6xf32> {
+  %0 = tensor.empty() : tensor<1x2x3x4x6xf32>
+  %1 = linalg.conv_2d_nhwgc_gfhwc {
+    dilations = dense<1> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } ins(%input, %filter : tensor<1x2x3x4x5xf32>, tensor<4x6x1x1x5xf32>) outs(%0 : tensor<1x2x3x4x6xf32>) -> tensor<1x2x3x4x6xf32>
+  util.return %1 : tensor<1x2x3x4x6xf32>
+}
+
+// CHECK-LABEL: @generalize_1x1_group_conv_2d
+// CHECK: %[[RESULT:.*]] = linalg.generic
+// CHECK: util.return %[[RESULT]]
+
+// -----
+
 util.func public @no_generalize_1x1_conv_2d_strides(%input: tensor<1x7x7x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x4x7xf32> {
   %0 = tensor.empty() : tensor<1x4x4x7xf32>
   %1 = linalg.conv_2d_nhwc_hwcf {

@@ -113,3 +128,18 @@ util.func public @no_generalize_1x1_conv_2d_strides(%input: tensor<1x7x7x2xf32>,
 // CHECK-LABEL: @no_generalize_1x1_conv_2d_strides
 // CHECK-NOT: linalg.generic
 // CHECK: util.return
+
+// -----
+
+util.func public @no_generalize_1x1_depthwise_conv(%input: tensor<1x2x3x4xf32>, %filter: tensor<1x1x4xf32>) -> tensor<1x2x3x4xf32> {
+  %0 = tensor.empty() : tensor<1x2x3x4xf32>
+  %1 = linalg.depthwise_conv_2d_nhwc_hwc {
+    dilations = dense<1> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } ins(%input, %filter : tensor<1x2x3x4xf32>, tensor<1x1x4xf32>) outs(%0 : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
+  util.return %1 : tensor<1x2x3x4xf32>
+}
+
+// CHECK-LABEL: @no_generalize_1x1_depthwise_conv
+// CHECK-NOT: linalg.generic
+// CHECK: util.return
