diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index d9840e3923c4f..c5b1a3b55126c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1199,7 +1199,7 @@ static void getGenericEffectsImpl( &effects, LinalgOp linalgOp) { for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) { - if (!llvm::isa(operand.getType())) + if (!operand || !llvm::isa(operand.getType())) continue; effects.emplace_back( MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0, @@ -1207,7 +1207,7 @@ static void getGenericEffectsImpl( } for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { - if (!llvm::isa(operand.get().getType())) + if (!operand.get() || !llvm::isa(operand.get().getType())) continue; if (linalgOp.payloadUsesValueFromOperand(&operand)) { effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0, diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 9273ac01e7cce..fe7bcbc7c490b 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -73,6 +73,32 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { // ----- +// Checking that the arguments of linalg.generic are properly handled +// All code below is removed as a result of the pass +// +#map = affine_map<(d0, d1, d2) -> (0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +module { + func.func @main() { + %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32> + %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32> + // CHECK-NOT: arith.constant + %0 = tensor.empty() : tensor<1x25x13xi32> + // CHECK-NOT: tensor + %1 = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_3, %cst_7 : tensor<1x25x13xi32>, tensor<1x25x13xi32>) outs(%0 : tensor<1x25x13xi32>) { + // CHECK-NOT: linalg.generic + ^bb0(%in: i32, %in_15: i32, %out: i32): + %29 = arith.xori %in, %in_15 : i32 + // CHECK-NOT: arith.xori + linalg.yield %29 : i32 + // CHECK-NOT: linalg.yield + } -> tensor<1x25x13xi32> + return + } +} + +// ----- + // Note that this cleanup cannot be done by the `canonicalize` pass. // // CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {