Skip to content

Commit 25e6576

Browse files
committed
Check linalg.generic arguments to prevent crashing when they are deleted
1 parent f7e8be7 commit 25e6576

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,15 +1199,15 @@ static void getGenericEffectsImpl(
11991199
&effects,
12001200
LinalgOp linalgOp) {
12011201
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1202-
if (!llvm::isa<MemRefType>(operand.getType()))
1202+
if (!operand || !llvm::isa<MemRefType>(operand.getType()))
12031203
continue;
12041204
effects.emplace_back(
12051205
MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
12061206
/*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
12071207
}
12081208

12091209
for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1210-
if (!llvm::isa<MemRefType>(operand.get().getType()))
1210+
if (!operand.get() || !llvm::isa<MemRefType>(operand.get().getType()))
12111211
continue;
12121212
if (linalgOp.payloadUsesValueFromOperand(&operand)) {
12131213
effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,

mlir/test/Transforms/remove-dead-values.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,32 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
7373

7474
// -----
7575

76+
// Checking that the arguments of linalg.generic are properly handled
77+
// All code below is removed as a result of the pass
78+
//
79+
#map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
80+
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
81+
module {
82+
func.func @main() {
83+
%cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
84+
%cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
85+
// CHECK-NOT: arith.constant
86+
%0 = tensor.empty() : tensor<1x25x13xi32>
87+
// CHECK-NOT: tensor
88+
%1 = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_3, %cst_7 : tensor<1x25x13xi32>, tensor<1x25x13xi32>) outs(%0 : tensor<1x25x13xi32>) {
89+
// CHECK-NOT: linalg.generic
90+
^bb0(%in: i32, %in_15: i32, %out: i32):
91+
%29 = arith.xori %in, %in_15 : i32
92+
// CHECK-NOT: arith.xori
93+
linalg.yield %29 : i32
94+
// CHECK-NOT: linalg.yield
95+
} -> tensor<1x25x13xi32>
96+
return
97+
}
98+
}
99+
100+
// -----
101+
76102
// Note that this cleanup cannot be done by the `canonicalize` pass.
77103
//
78104
// CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {

0 commit comments

Comments
 (0)