
Commit c9aa55d

[mlir][Linalg] Add speculation for LinalgStructuredOps (llvm#108032)
This patch adds speculation behavior for linalg structured ops, allowing them to be hoisted out of loops using LICM.
1 parent 34cab2e commit c9aa55d
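
In MLIR, loop-invariant code motion only hoists an op that is both free of memory effects and safe to execute speculatively. Before this patch, linalg structured ops implemented only MemoryEffectsOpInterface, so LICM left them inside loops even when they operated purely on tensors. A minimal sketch of the check a hoisting pass performs, using the existing helpers from mlir/Interfaces/SideEffectInterfaces.h (the helper name canHoistOutOfLoop is illustrative, not part of this commit):

    #include "mlir/IR/Operation.h"
    #include "mlir/Interfaces/SideEffectInterfaces.h"

    // Illustrative helper: LICM-style hoisting needs both properties.
    // mlir::isPure(op) bundles the same two checks.
    static bool canHoistOutOfLoop(mlir::Operation *op) {
      return mlir::isMemoryEffectFree(op) && mlir::isSpeculatable(op);
    }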

4 files changed, +127 −1 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
   : Op<Linalg_Dialect, mnemonic, !listconcat([
       SingleBlockImplicitTerminator<"YieldOp">,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      DeclareOpInterfaceMethods<ConditionallySpeculatable>,
       DestinationStyleOpInterface,
       LinalgStructuredInterface,
       ReifyRankedShapedTypeOpInterface], props)> {
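
Declaring DeclareOpInterfaceMethods<ConditionallySpeculatable> here makes ODS emit a getSpeculatability declaration on every op derived from LinalgStructuredBase_Op; the definitions follow in LinalgOps.cpp below. Roughly, each generated op class gains a member of this shape (a sketch of the generated declaration, not code from this commit):

    // Added to each generated op class by the interface declaration:
    ::mlir::Speculation::Speculatability getSpeculatability();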

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 31 additions & 0 deletions
@@ -34,6 +34,7 @@
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"

 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallSet.h"
@@ -1202,6 +1203,20 @@ void GenericOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }

+static Speculation::Speculatability
+getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
+  // Operands with value semantics are speculatable, while operands with memory
+  // semantics are not.
+  if (!linalgOp.hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
+  // The body of the op can still have speculation in its region.
+  return Speculation::RecursivelySpeculatable;
+}
+
+Speculation::Speculatability GenericOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 LogicalResult GenericOp::verify() { return success(); }

 namespace {
@@ -1553,6 +1568,10 @@ void MapOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }

+Speculation::Speculatability MapOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 //===----------------------------------------------------------------------===//
 // ReduceOp
 //===----------------------------------------------------------------------===//
@@ -1621,6 +1640,10 @@ void ReduceOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }

+Speculation::Speculatability ReduceOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                           NamedAttrList &attributes,
                                           StringRef attributeName) {
@@ -1906,6 +1929,10 @@ void TransposeOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }

+Speculation::Speculatability TransposeOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &result) {
   // Only the tensor type is supported.
@@ -2134,6 +2161,10 @@ void BroadcastOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }

+Speculation::Speculatability BroadcastOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
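
All five hand-written ops share getGenericSpeculatabilityImpl: with pure tensor semantics the op itself is safe to speculate, but RecursivelySpeculatable defers the final verdict to the ops nested in the body. A condensed sketch of how that recursive rule is evaluated (the real logic is mlir::isSpeculatable in SideEffectInterfaces.cpp; this version is simplified for illustration):

    #include "mlir/IR/Operation.h"
    #include "mlir/Interfaces/SideEffectInterfaces.h"
    #include "llvm/Support/ErrorHandling.h"

    // Condensed sketch: an op is speculatable outright, not at all, or only
    // if everything nested inside its regions is speculatable too.
    static bool isSpeculatableSketch(mlir::Operation *op) {
      auto speculatable = mlir::dyn_cast<mlir::ConditionallySpeculatable>(op);
      if (!speculatable)
        return false;
      switch (speculatable.getSpeculatability()) {
      case mlir::Speculation::NotSpeculatable:
        return false;
      case mlir::Speculation::Speculatable:
        return true;
      case mlir::Speculation::RecursivelySpeculatable:
        // The case linalg ops with tensor semantics hit: recurse into the
        // body, so a trapping op (e.g. an unguarded division) blocks hoisting.
        for (mlir::Region &region : op->getRegions())
          for (mlir::Operation &nested : region.getOps())
            if (!isSpeculatableSketch(&nested))
              return false;
        return true;
      }
      llvm_unreachable("unknown speculatability");
    }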

mlir/test/Transforms/loop-invariant-code-motion.mlir

Lines changed: 91 additions & 0 deletions
@@ -1118,3 +1118,94 @@ func.func @hoist_from_scf_while(%arg0: i32, %arg1: i32) -> i32 {
   }
   return %0 : i32
 }
+
+// -----
+
+#trait = {
+  indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @hoist_linalg_ops
+// CHECK: linalg.generic
+// CHECK: scf.for
+// CHECK-NOT: linalg.generic
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+func.func @hoist_linalg_ops(%a : tensor<128x128xf32>,
+                            %b : tensor<128x128xf32>,
+                            %c: tensor<128x128xf32>,
+                            %lb : index,
+                            %ub : index,
+                            %step : index,
+                            %output : tensor<?x128xf32>) -> tensor<?x128xf32> {
+  %final =
+  scf.for %i = %lb to %ub step %step iter_args(%acc = %output)
+                                     -> tensor<?x128xf32> {
+    %compute = linalg.generic #trait
+                 ins(%a, %b : tensor<128x128xf32>, tensor<128x128xf32>)
+                 outs(%c : tensor<128x128xf32>) {
+               ^bb0(%in : f32, %in2 : f32, %in3 : f32):
+                 %mul = arith.mulf %in, %in2 : f32
+                 %add = arith.addf %mul, %in3 : f32
+                 linalg.yield %in3 : f32
+               } -> tensor<128x128xf32>
+
+    %newacc = tensor.insert_slice %compute into
+              %output[%i, 0][128, 128][1, 1]
+              : tensor<128x128xf32> into tensor<?x128xf32>
+    scf.yield %newacc : tensor<?x128xf32>
+  }
+
+  func.return %final : tensor<?x128xf32>
+}
+
+// -----
+
+#trait = {
+  indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @hoist_linalg_ops_div_by_zero
+// CHECK-NOT: linalg.generic
+// CHECK: scf.for
+// CHECK: linalg.generic
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+func.func @hoist_linalg_ops_div_by_zero(%a : tensor<128x128xi32>,
+                                        %b : tensor<128x128xi32>,
+                                        %c: tensor<128x128xi32>,
+                                        %lb : index,
+                                        %ub : index,
+                                        %step : index,
+                                        %output : tensor<?x128xi32>) -> tensor<?x128xi32> {
+  %cst0 = arith.constant 0 : i32
+  %final =
+  scf.for %i = %lb to %ub step %step iter_args(%acc = %output)
+                                     -> tensor<?x128xi32> {
+    %compute = linalg.generic #trait
+                 ins(%a, %b : tensor<128x128xi32>, tensor<128x128xi32>)
+                 outs(%c : tensor<128x128xi32>) {
+               ^bb0(%in : i32, %in2 : i32, %in3 : i32):
+                 %div = arith.divui %in, %in2 : i32
+                 %add = arith.addi %div, %in3 : i32
+                 linalg.yield %in3 : i32
+               } -> tensor<128x128xi32>
+
+    %newacc = tensor.insert_slice %compute into
+              %output[%i, 0][128, 128][1, 1]
+              : tensor<128x128xi32> into tensor<?x128xi32>
+    scf.yield %newacc : tensor<?x128xi32>
+  }
+
+  func.return %final : tensor<?x128xi32>
+}
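
The two tests exercise both sides of the recursive rule: the mulf/addf body is speculatable, so the first linalg.generic is hoisted above the scf.for, while arith.divui may divide by zero, reports itself non-speculatable, and therefore pins the second linalg.generic inside the loop. A hypothetical, simplified version of that kind of divisor check (not the actual arith implementation):

    #include "llvm/ADT/APInt.h"
    #include "mlir/IR/Matchers.h"
    #include "mlir/Interfaces/SideEffectInterfaces.h"

    // Hypothetical sketch: unsigned division is only safe to speculate when
    // the divisor is statically known to be non-zero.
    static mlir::Speculation::Speculatability
    divisorSpeculatabilitySketch(mlir::Value divisor) {
      llvm::APInt value;
      if (mlir::matchPattern(divisor, mlir::m_ConstantInt(&value)) &&
          !value.isZero())
        return mlir::Speculation::Speculatable;
      return mlir::Speculation::NotSpeculatable;
    }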

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 4 additions & 1 deletion
@@ -656,7 +656,7 @@ ArrayAttr {0}::getIndexingMaps() {{
 }
 )FMT";

-// Implementations of fold and getEffects.
+// Implementations of fold, getEffects and getSpeculatability.
 // Parameters:
 // {0}: Class name
 const char structuredOpFoldersFormat[] = R"FMT(
@@ -669,6 +669,9 @@ void {0}::getEffects(SmallVectorImpl<
   if (hasPureTensorSemantics()) return;
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
+Speculation::Speculatability {0}::getSpeculatability() {{
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
 )FMT";

 // Implementation of parse/print.
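
This template change is what extends the behavior beyond the five hand-written ops: every named structured op generated from the YAML definitions receives the same getSpeculatability body, with the op's class name substituted for {0}. For linalg.matmul, for instance, the emitted code would read:

    Speculation::Speculatability MatmulOp::getSpeculatability() {
      return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
    }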
