Skip to content

Commit df03118

Browse files
committed
address comments
1 parent ae69b88 commit df03118

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -689,16 +689,10 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
689689
collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
690690
reassociationMap, options.rankReductionStrategy);
691691

692-
<<<<<<< HEAD
693-
auto newPadOp = tensor::PadOp::create(
694-
rewriter, padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
695-
=======
696-
auto resultType = RankedTensorType::get(
697-
newResultShape, padOp.getResultType().getElementType()); //,
698-
// padOp.getResultType().getEncoding());
692+
auto newResultType = RankedTensorType::get(
693+
newResultShape, padOp.getResultType().getElementType());
699694
auto newPadOp = rewriter.create<tensor::PadOp>(
700-
padOp.getLoc(), /*result=*/resultType, collapsedSource, newLowPad,
701-
>>>>>>> 617b335dec04 ([MLIR] Specify outputType in pad op unitDim drop)
695+
padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad,
702696
newHighPad, paddingVal, padOp.getNofold());
703697

704698
Value dest = padOp.getResult();

mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,12 +1088,12 @@ func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<
10881088
}
10891089
// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic
10901090
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
1091-
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
1091+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARGS:.*]] : tensor<1x?xf32> into tensor<?xf32>
10921092
// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] {
10931093
// CHECK: ^bb0(%[[IDX:.*]]: index):
10941094
// CHECK: tensor.yield %[[CST]] : f32
10951095
// CHECK: } : tensor<?xf32> to tensor<16xf32>
1096-
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]]
1096+
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32>
10971097
// CHECK: return %[[EXPANDED]] : tensor<1x16xf32>
10981098

10991099
// -----

0 commit comments

Comments
 (0)