Skip to content

Commit 6ec0040

Browse files
committed
remove duplicates by marking operands not written-to
1 parent 0c44188 commit 6ec0040

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,7 +1059,7 @@ struct ConcatOpInterface
10591059

10601060
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
10611061
const AnalysisState &state) const {
1062-
return true;
1062+
return false;
10631063
}
10641064

10651065
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
@@ -1069,7 +1069,7 @@ struct ConcatOpInterface
10691069

10701070
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
10711071
const AnalysisState &state) const {
1072-
return {{op->getResult(0), BufferRelation::Equivalent}};
1072+
return {};
10731073
}
10741074

10751075
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -1109,7 +1109,7 @@ struct ConcatOpInterface
11091109
sizes.push_back(rewriter.getIndexAttr(dimSize));
11101110
}
11111111

1112-
int concatDimOffset = 0;
1112+
int64_t concatDimOffset = 0;
11131113
for (auto operand : concatOp.getInputs()) {
11141114
// Get the buffer for the operand.
11151115
FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -618,13 +618,11 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
618618
// CHECK-LABEL: func @tensor.concat(
619619
// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
620620
// CHECK: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
621-
// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
622-
// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC_2:.*]] :
623621
// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
624622
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
625-
// CHECK: memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
623+
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
626624
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
627-
// CHECK: memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
625+
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW2]]
628626
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
629627
// CHECK: return %[[RET]]
630628
// CHECK: }
@@ -640,13 +638,11 @@ func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
640638
// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
641639
// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
642640
// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
643-
// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
644-
// CHECK: memref.copy %[[G_MEMREF]], %[[F_ALLOC_2:.*]] :
645641
// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
646642
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
647-
// CHECK: memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
643+
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
648644
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
649-
// CHECK: memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
645+
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
650646
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
651647
// CHECK: return %[[RET]]
652648
// CHECK: }

0 commit comments

Comments
 (0)