Skip to content

Commit 5a1622d

Browse files
committed
fix bufferizations that generate MapOp
1 parent 7a1d5f3 commit 5a1622d

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
579579
linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
580580
/*init=*/tensorDestination);
581581
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
582+
linalgBody.addArgument(tensorType.getElementType(), loc);
582583

583584
// Create linalg::IndexOps.
584585
rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
10681069
/*inputs=*/ValueRange(),
10691070
/*init=*/*tensorAlloc);
10701071
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1072+
linalgBody.addArgument(tensorType.getElementType(), loc);
10711073

10721074
// Create linalg::IndexOps.
10731075
rewriter.setInsertionPointToStart(&linalgBody);

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
728728
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
729729
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
730730
// CHECK: %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
731-
// CHECK: () {
731+
// CHECK: (%[[INIT:.*]]: f32) {
732732
// CHECK: linalg.yield %[[F]] : f32
733733
// CHECK: }
734734
// CHECK: return %[[MAPPED]] : tensor<?x3x?xf32>

0 commit comments

Comments
 (0)