Skip to content

Commit d193ac4

Browse files
authored
[mlir][vector] Drop inner unit dims for xWrite on dynamic shapes. (#80725)
This is part of 66347e5 The regression in downstream projects is about transfer_read patterns, which needs more investigation. Add the support for transfer_write for now.
1 parent 9a5fb74 commit d193ac4

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,7 +1318,7 @@ class DropInnerMostUnitDimsTransferWrite
13181318
return failure();
13191319

13201320
auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
1321-
if (!srcType || !srcType.hasStaticShape())
1321+
if (!srcType)
13221322
return failure();
13231323

13241324
if (!writeOp.getPermutationMap().isMinorIdentity())
@@ -1341,20 +1341,23 @@ class DropInnerMostUnitDimsTransferWrite
13411341
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
13421342
targetType.getElementType());
13431343

1344+
Location loc = writeOp.getLoc();
1345+
SmallVector<OpFoldResult> sizes =
1346+
memref::getMixedSizes(rewriter, loc, writeOp.getSource());
1347+
SmallVector<OpFoldResult> offsets(srcType.getRank(),
1348+
rewriter.getIndexAttr(0));
1349+
SmallVector<OpFoldResult> strides(srcType.getRank(),
1350+
rewriter.getIndexAttr(1));
13441351
MemRefType resultMemrefType =
13451352
getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
1346-
SmallVector<int64_t> offsets(srcType.getRank(), 0);
1347-
SmallVector<int64_t> strides(srcType.getRank(), 1);
13481353
ArrayAttr inBoundsAttr =
13491354
writeOp.getInBounds()
13501355
? rewriter.getArrayAttr(
13511356
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
13521357
: ArrayAttr();
13531358

1354-
Location loc = writeOp.getLoc();
13551359
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1356-
loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
1357-
strides);
1360+
loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
13581361
auto permMap = getTransferMinorIdentityMap(
13591362
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
13601363

mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,27 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
119119

120120
// -----
121121

122+
func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
123+
%c0 = arith.constant 0 : index
124+
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
125+
{in_bounds = [true, true, true, true]}
126+
: vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
127+
return
128+
}
129+
// CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
130+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
131+
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
132+
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
133+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
134+
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
135+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
136+
// CHECK-SAME: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
137+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
138+
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
139+
// CHECK-SAME: [%[[IDX]], %[[C0]], %[[C0]]]
140+
141+
// -----
142+
122143
func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
123144
%c0 = arith.constant 0 : index
124145
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]

0 commit comments

Comments
 (0)