Skip to content

Commit 4c0f2a7

Browse files
committed
[mlir][vector] Convert vector.transfer_read to scalar load and broadcast
If we use vector.transfer_read to read from a 0-d value, we can convert it to memref.load from the 0-d value then broadcast the value to the target vector type. It can avoid generating vector operations breaking the requirements of convertVectorToMMAOps. The patterns in convertVectorToMMAOps expect all vector.transfer_read have 2-D vector types. Instead of %s0 = vector.transfer_read %base[] : memref<dtype> to vector<dtype> %s1 = vector.broadcast %s0 : vector<dtype> to vector<d0...d1 x dtype> Use %s0 = memref.load %base[] : memref<dtype> %s1 = vector.broadcast %s0 : dtype to vector<d0...d1 x dtype>
1 parent d6315a2 commit 4c0f2a7

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -360,17 +360,35 @@ struct TransferOpReduceRank
360360
SmallVector<bool> newScalableDims(
361361
originalVecType.getScalableDims().take_back(reducedShapeRank));
362362

363-
VectorType newReadType = VectorType::get(
364-
newShape, originalVecType.getElementType(), newScalableDims);
365-
ArrayAttr newInBoundsAttr =
366-
op.getInBounds()
367-
? rewriter.getArrayAttr(
368-
op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
369-
: ArrayAttr();
370-
Value newRead = vector::TransferReadOp::create(
371-
rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
372-
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
373-
newInBoundsAttr);
363+
Value newRead;
364+
if (newShape.size() == 0 && newScalableDims.size() == 0) {
365+
// Handle the scalar case.
366+
// Convert
367+
// %val = vector.transfer_read %base[] : memref<dtype> to
368+
// vector<d0 x d1 x dtype>
369+
// into
370+
// %scalar = memref.load %base[] : memref<dtype>
371+
// %val = vector.broadcast %scalar : dtype to vector<d0 x d1 x dtype>
372+
Type baseType = op.getBase().getType();
373+
if (isa<MemRefType>(baseType)) {
374+
newRead = memref::LoadOp::create(rewriter, op.getLoc(), op.getBase(),
375+
op.getIndices());
376+
}
377+
}
378+
379+
if (!newRead) {
380+
VectorType newReadType = VectorType::get(
381+
newShape, originalVecType.getElementType(), newScalableDims);
382+
ArrayAttr newInBoundsAttr =
383+
op.getInBounds()
384+
? rewriter.getArrayAttr(
385+
op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
386+
: ArrayAttr();
387+
newRead = vector::TransferReadOp::create(
388+
rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
389+
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
390+
newInBoundsAttr);
391+
}
374392
return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
375393
newRead)
376394
.getVector();

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,24 @@ func.func @xfer_read_minor_identitiy_bcast_dims(
388388
return %res : vector<8x4x2x3xf32>
389389
}
390390

391+
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_scalar
392+
// CHECK-SAME: %[[MEM:.*]]: memref<f32>) -> vector<8x4x2x3xf32> {
393+
// CHECK: %[[LOAD:.*]] = memref.load %[[MEM]][] : memref<f32>
394+
// CHECK: %[[BC:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<8x4x2x3xf32>
395+
// CHECK: return %[[BC]] : vector<8x4x2x3xf32>
396+
func.func @xfer_read_minor_identitiy_bcast_scalar(
397+
%mem: memref<f32>) -> vector<8x4x2x3xf32> {
398+
399+
%pad = arith.constant 0.000000e+00 : f32
400+
401+
%res = vector.transfer_read %mem[], %pad {
402+
in_bounds = [true, true, true, true],
403+
permutation_map = affine_map<() -> (0, 0, 0, 0)>
404+
} : memref<f32>, vector<8x4x2x3xf32>
405+
406+
return %res : vector<8x4x2x3xf32>
407+
}
408+
391409
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable
392410
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
393411
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>

0 commit comments

Comments
 (0)