diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 834c02126fa53..f24310ecd7beb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -455,6 +455,8 @@ static LogicalResult generateLoopNestUsingForOp(
     rewriter.setInsertionPointToEnd(loop.getBody());
     destinationTensors = loop.getRegionIterArgs();
   }
+  if (loops.empty())
+    return success();
 
   SmallVector<Value> tiledResults;
   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
@@ -463,9 +465,6 @@ static LogicalResult generateLoopNestUsingForOp(
     return rewriter.notifyMatchFailure(
         loc, "failed to generate inner tile loop body");
   }
-  if (loops.empty())
-    return success();
-
   assert(tiledResults.size() == destinationTensors.size() &&
          "Number of results of body should be equal to number of iter args");
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 7bac850d0b7fe..0466a7ba3e2ea 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -266,3 +266,23 @@ func.func @tile_linalg_matmul(
     -> tensor<128x128xf32>
   return %0 : tensor<128x128xf32>
 }
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{op expected number of loops to tile (0) to match number of `loops` results (1)}}
+    %1, %loops = transform.structured.tile_using_for %0 tile_sizes [0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+func.func @tile_linalg_matmul(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}