Skip to content

Commit c5b3c39

Browse files
committed
add test
1 parent eee8805 commit c5b3c39

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -408,11 +408,10 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
408408
ShapedType collapsedType;
409409
if (stripMinedType.isa<TensorType>()) {
410410
collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
411-
stripMinedType.cast<RankedTensorType>(),
412-
packingMetadata.reassociations);
411+
cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
413412
} else if (stripMinedType.isa<MemRefType>()) {
414413
collapsedType = memref::CollapseShapeOp::inferCollapsedType(
415-
stripMinedType.cast<MemRefType>(), packingMetadata.reassociations);
414+
cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
416415
}
417416

418417
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm

mlir/test/Dialect/Linalg/loops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,3 +942,19 @@ func.func @transpose(%input: memref<?xf32>,
942942
// CHECKPARALLEL: }
943943
// CHECKPARALLEL: return
944944
// CHECKPARALLEL: }
945+
946+
// Test that we can lower all the way to LLVM without crashing, don't check results here.
947+
func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
948+
%dest = memref.alloc() : memref<8x16x8x32xf32>
949+
%packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
950+
into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
951+
return %packed : memref<8x16x8x32xf32>
952+
}
953+
954+
// Test that we can lower all the way to LLVM without crashing, don't check results here.
955+
func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
956+
%dest = memref.alloc() : memref<128x256xf32>
957+
%unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
958+
into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
959+
return %unpacked : memref<128x256xf32>
960+
}

0 commit comments

Comments
 (0)