
Commit a5d01df

fix upon review

1 parent 714e4c4 commit a5d01df

8 files changed: +71 -67 lines changed

mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@
 #define LINALG_IR_RELAYOUTOPINTERFACE
 
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
-include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/IR/OpBase.td"
 
 def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 1 addition & 6 deletions
@@ -1782,11 +1782,6 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
 
     static MemRefType computeCollapsedType(
         MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
-    static MemRefType
-    inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
-    static MemRefType
-    inferCollapsedType(MemRefType type,
-                       SmallVector<ReassociationIndices> reassociation);
   }];
 
   let hasVerifier = 1;

@@ -1806,7 +1801,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
-
+
     The number of indices must match the rank of the memref. The indices must
     be in-bounds: `0 <= idx < dim_size`
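With the inferCollapsedType overloads removed from the declaration above, computeCollapsedType is the one remaining static helper for deriving the collapsed result type. A minimal usage sketch under that assumption (builder, shapes, and grouping illustrative, not from this commit):

    // Collapse a rank-3 memref into rank-2 by grouping dims {0, 1} and {2}.
    MemRefType srcType = MemRefType::get({4, 8, 16}, builder.getF32Type());
    SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}};
    MemRefType collapsed =
        memref::CollapseShapeOp::computeCollapsedType(srcType, reassociation);
    // For the identity layout this yields memref<32x16xf32>.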

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 8 additions & 6 deletions
@@ -5001,12 +5001,8 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
 }
 
 bool PackOp::isLikePad() {
-  if (auto packedTensorType =
-          llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
-    return isLikePadUnPad(*this, packedTensorType);
-  if (auto packedTensorType =
-          llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
-    return isLikePadUnPad(*this, packedTensorType);
+  auto packedTensorType = llvm::dyn_cast<ShapedType>((*this)->getResultTypes().front());
+  return isLikePadUnPad(*this, packedTensorType);
 }
 
 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
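One observable fix in the hunk above: the old two-branch form could fall off the end of the function without returning a value when neither cast matched, which is undefined behavior in C++. The replacement always returns, since both RankedTensorType and MemRefType implement the ShapedType interface. A small standalone illustration of that subsumption (not from this commit):

    // One dyn_cast<ShapedType> covers both branches the old code spelled out.
    Type resultType = op->getResultTypes().front(); // tensor<...> or memref<...>
    if (auto shaped = llvm::dyn_cast<ShapedType>(resultType))
      llvm::errs() << "rank: " << shaped.getRank() << "\n"; // shared ShapedType API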
@@ -5042,6 +5038,9 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();
 
+    if (!op.hasPureTensorSemantics())
+      return failure();
+
     SmallVector<Type> newResultTypes(op->getResultTypes());
     SmallVector<Value> newOperands =
         tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

@@ -5310,6 +5309,9 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();
 
+    if (!op.hasPureTensorSemantics())
+      return failure();
+
     SmallVector<Type> newResultTypes(op->getResultTypes());
     SmallVector<Value> newOperands =
         tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
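The two hasPureTensorSemantics() guards above stop the tensor.cast folders from firing now that pack/unpack may also carry memref operands. A minimal sketch of that early-exit shape in a standalone pattern (pattern name hypothetical, rewrite body elided):

    // Hypothetical pattern mirroring the guard added in this commit.
    struct ExampleTensorOnlyFolder : public OpRewritePattern<linalg::PackOp> {
      using OpRewritePattern<linalg::PackOp>::OpRewritePattern;
      LogicalResult matchAndRewrite(linalg::PackOp op,
                                    PatternRewriter &rewriter) const override {
        // Bail out on the memref form; the rewrite below is tensor-only.
        if (!op.hasPureTensorSemantics())
          return failure();
        // ... tensor-only folding logic would go here ...
        return failure();
      }
    };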

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 15 additions & 9 deletions
@@ -171,25 +171,27 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
     return success();
   }
 
-  LogicalResult matchAndRewrite(UnPackOp unpackOp,
+  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
-    ShapedType destType = unpackOp.getDestType();
-    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
-        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
-                          unpackOp.getStaticTiles())) &&
-        !unpackOp.isLikeUnPad()) {
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
+    ShapedType destType = unPackOp.getDestType();
+    if (failed(isUnpackOnInnerMostDim(rewriter, unPackOp)) &&
+        failed(isPackOn1D(rewriter, unPackOp, destType.getShape(),
+                          unPackOp.getStaticTiles())) &&
+        !unPackOp.isLikeUnPad()) {
       return failure();
     }
 
-    ShapedType sourceType = unpackOp.getSourceType();
+    ShapedType sourceType = unPackOp.getSourceType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
       return failure();
     Value collapsed = insertCollapse(
-        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
+        rewriter, unPackOp.getLoc(), unPackOp.getSource(), destType,
         getReassociationIndicesAttribute(rewriter, *reassociation));
-    rewriter.replaceOp(unpackOp, collapsed);
+    rewriter.replaceOp(unPackOp, collapsed);
     return success();
   }
 };

@@ -426,6 +428,8 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
     auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
     if (!linalgOp)
       return failure();

@@ -507,6 +511,8 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
     // Check for tensor.empty source.
     auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
     if (!emptyOp)
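For context, patterns like the ones guarded above are normally collected into a RewritePatternSet and applied by the greedy driver. A hedged sketch of that wiring (the populate entry point is an assumption, not shown in this diff):

    // Inside a pass; populateSimplifyPackAndUnpackPatterns is assumed here.
    RewritePatternSet patterns(&getContext());
    linalg::populateSimplifyPackAndUnpackPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();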

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 1 addition & 1 deletion
@@ -410,7 +410,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
     collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
         cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
   } else if (stripMinedType.isa<MemRefType>()) {
-    collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+    collapsedType = memref::CollapseShapeOp::computeCollapsedType(
        cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
   }
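A plausible motivation for this swap, judging from the bodies elsewhere in this commit: inferCollapsedType (now disabled in MemRefOps.cpp below) returned MemRefType::get(newShape, type.getElementType()), discarding layout and memory space, while computeCollapsedType threads srcType.getMemorySpace() through. An illustrative contrast (overloads and values assumed, not from this commit):

    // A source memref in memory space 3; layout left as the default.
    Attribute space = builder.getI64IntegerAttr(3);
    MemRefType src = MemRefType::get({4, 8}, builder.getF32Type(),
                                     MemRefLayoutAttrInterface(), space);
    SmallVector<ReassociationIndices> groups = {{0, 1}};
    MemRefType dst = memref::CollapseShapeOp::computeCollapsedType(src, groups);
    // dst keeps the space: memref<32xf32, 3>. The removed inferCollapsedType
    // would have produced memref<32xf32>, losing it.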

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 28 additions & 28 deletions
@@ -2526,34 +2526,34 @@ MemRefType CollapseShapeOp::computeCollapsedType(
       srcType.getMemorySpace());
 }
 
-MemRefType
-CollapseShapeOp::inferCollapsedType(MemRefType type,
-                                    ArrayRef<AffineMap> reassociation) {
-  auto shape = type.getShape();
-  SmallVector<int64_t, 4> newShape;
-  assert(isReassociationValid(reassociation) && "invalid reassociation");
-  unsigned currentDim = 0;
-  for (AffineMap m : reassociation) {
-    unsigned dim = m.getNumResults();
-    auto band = shape.slice(currentDim, dim);
-    int64_t size = 1;
-    if (llvm::is_contained(band, ShapedType::kDynamic))
-      size = ShapedType::kDynamic;
-    else
-      for (unsigned d = 0; d < dim; ++d)
-        size *= shape[currentDim + d];
-    newShape.push_back(size);
-    currentDim += dim;
-  }
-  return MemRefType::get(newShape, type.getElementType());
-}
-
-MemRefType CollapseShapeOp::inferCollapsedType(
-    MemRefType type, SmallVector<ReassociationIndices> reassociation) {
-  return inferCollapsedType(
-      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
-                type.getContext(), reassociation)));
-}
+// MemRefType
+// CollapseShapeOp::inferCollapsedType(MemRefType type,
+//                                     ArrayRef<AffineMap> reassociation) {
+//   auto shape = type.getShape();
+//   SmallVector<int64_t, 4> newShape;
+//   assert(isReassociationValid(reassociation) && "invalid reassociation");
+//   unsigned currentDim = 0;
+//   for (AffineMap m : reassociation) {
+//     unsigned dim = m.getNumResults();
+//     auto band = shape.slice(currentDim, dim);
+//     int64_t size = 1;
+//     if (llvm::is_contained(band, ShapedType::kDynamic))
+//       size = ShapedType::kDynamic;
+//     else
+//       for (unsigned d = 0; d < dim; ++d)
+//         size *= shape[currentDim + d];
+//     newShape.push_back(size);
+//     currentDim += dim;
+//   }
+//   return MemRefType::get(newShape, type.getElementType());
+// }
+
+// MemRefType CollapseShapeOp::inferCollapsedType(
+//     MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+//   return inferCollapsedType(
+//       type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+//                 type.getContext(), reassociation)));
+// }
 
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,

mlir/test/Dialect/Linalg/loops.mlir

Lines changed: 0 additions & 16 deletions
@@ -942,19 +942,3 @@ func.func @transpose(%input: memref<?xf32>,
 // CHECKPARALLEL:   }
 // CHECKPARALLEL:   return
 // CHECKPARALLEL: }
-
-// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
-  %dest = memref.alloc() : memref<8x16x8x32xf32>
-  %packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
-    into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
-  return %packed : memref<8x16x8x32xf32>
-}
-
-// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
-  %dest = memref.alloc() : memref<128x256xf32>
-  %unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
-    into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
-  return %unpacked : memref<128x256xf32>
-}

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 18 additions & 0 deletions
@@ -706,3 +706,21 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK-LABEL: func @conv2d_channel_first_q_promote(
 // CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
 // CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
+func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
+  %dest = memref.alloc() : memref<8x16x8x32xf32>
+  %packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+    into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
+  return %packed : memref<8x16x8x32xf32>
+}
+
+// -----
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
+  %dest = memref.alloc() : memref<128x256xf32>
+  %unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+    into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
+  return %unpacked : memref<128x256xf32>
+}
