Commit 71e3de8
[mlir][vector] Missing indices on vectorization of 1-d reduction to rank-1 memref (#166959)

Vectorization of a 1-d reduction whose output is a rank-1 memref could generate an invalid `vector.transfer_write` with no indices for the memref, e.g.:

    "vector.transfer_write"(%vec, %buff) <{...}> : (vector<f32>, memref<1xf32>) -> ()

This patch fixes the problem by providing the expected number of indices (i.e. matching the rank of the memref).
1 parent c62fc06 commit 71e3de8
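For reference, a sketch of the valid form the write takes once indices are supplied (custom op syntax; `%c0` stands for an `arith.constant 0 : index`, mirroring the CHECK lines in the new test below):

    vector.transfer_write %vec, %buff[%c0] : vector<f32>, memref<1xf32>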

File tree

2 files changed, +64 -12 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 4 additions & 4 deletions
@@ -746,12 +746,12 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
   auto vectorType = state.getCanonicalVecType(
       getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
 
+  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
+                             arith::ConstantIndexOp::create(rewriter, loc, 0));
+
   Operation *write;
   if (vectorType.getRank() > 0) {
     AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
-    SmallVector<Value> indices(
-        linalgOp.getRank(outputOperand),
-        arith::ConstantIndexOp::create(rewriter, loc, 0));
     value = broadcastIfNeeded(rewriter, value, vectorType);
     assert(value.getType() == vectorType && "Incorrect type");
     write = vector::TransferWriteOp::create(
@@ -762,7 +762,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
     value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
     assert(value.getType() == vectorType && "Incorrect type");
     write = vector::TransferWriteOp::create(rewriter, loc, value,
-                                            outputOperand->get(), ValueRange{});
+                                            outputOperand->get(), indices);
   }
 
   write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
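The new tests below exercise the fix with tensor semantics; the invalid IR in the commit message arises with memref outputs. A minimal repro for that case would be a 1-d reduction whose output is a rank-1 memref (hypothetical IR, not part of the patch):

func.func @reduce_to_rank_1_memref(%arg0: memref<32xf32>, %out: memref<1xf32>) {
  linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (0)>],
    iterator_types = ["reduction"]}
    ins(%arg0 : memref<32xf32>)
    outs(%out : memref<1xf32>) {
  ^bb0(%a: f32, %b: f32):
    %0 = arith.addf %a, %b : f32
    linalg.yield %0 : f32
  }
  return
}

Before this patch, vectorizing such an op produced a `vector.transfer_write` with an empty index list for `%out`; with the hoisted `indices` vector it now receives one zero index per memref dimension.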

mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir

Lines changed: 60 additions & 8 deletions
@@ -1481,23 +1481,23 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func @reduce_1d(
-// CHECK-SAME:    %[[A:.*]]: tensor<32xf32>
-func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
+// CHECK-LABEL: func @reduce_to_rank_0(
+// CHECK-SAME:    %[[SRC:.*]]: tensor<32xf32>
+func.func @reduce_to_rank_0(%arg0: tensor<32xf32>) -> tensor<f32> {
   // CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %f0 = arith.constant 0.000000e+00 : f32
 
-  // CHECK: %[[init:.*]] = tensor.empty() : tensor<f32>
+  // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<f32>
   %0 = tensor.empty() : tensor<f32>
 
   %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32>
-  // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
+  // CHECK: %[[R:.*]] = vector.transfer_read %[[SRC]][%[[C0]]]
   // CHECK-SAME: : tensor<32xf32>, vector<32xf32>
-  // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[F0]] [0]
+  // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[R]], %[[F0]] [0]
   // CHECK-SAME: : vector<32xf32> to f32
-  // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
-  // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][]
+  // CHECK: %[[RED_V1:.*]] = vector.broadcast %[[RED]] : f32 to vector<f32>
+  // CHECK: %[[RES:.*]] = vector.transfer_write %[[RED_V1]], %[[INIT]][]
   // CHECK-SAME: : vector<f32>, tensor<f32>
   %2 = linalg.generic {
     indexing_maps = [affine_map<(d0) -> (d0)>,
@@ -1523,6 +1523,58 @@ module attributes {transform.with_named_sequence} {
 }
 
 
+// -----
+
+// CHECK-LABEL: func @reduce_to_rank_1(
+// CHECK-SAME:    %[[SRC:.*]]: tensor<32xf32>
+func.func @reduce_to_rank_1(%arg0: tensor<32xf32>) -> tensor<1xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+  %f0 = arith.constant 0.000000e+00 : f32
+
+  // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32>
+  %0 = tensor.empty() : tensor<1xf32>
+
+  // CHECK: %[[INIT_ZERO:.*]] = vector.transfer_write %[[F0]], %[[INIT]][%[[C0]]]
+  // CHECK-SAME: : vector<1xf32>, tensor<1xf32>
+  %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: %[[R:.*]] = vector.transfer_read %[[SRC]][%[[C0]]]
+  // CHECK-SAME: : tensor<32xf32>, vector<32xf32>
+  // CHECK: %[[INIT_ZERO_VEC:.*]] = vector.transfer_read %[[INIT_ZERO]][%[[C0]]]
+  // CHECK-SAME: : tensor<1xf32>, vector<f32>
+  // CHECK: %[[INIT_ZERO_SCL:.*]] = vector.extract %[[INIT_ZERO_VEC]][]
+  // CHECK-SAME: : f32 from vector<f32>
+  // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[R]], %[[INIT_ZERO_SCL]] [0]
+  // CHECK-SAME: : vector<32xf32> to f32
+  // CHECK: %[[RED_V1:.*]] = vector.broadcast %[[RED]] : f32 to vector<f32>
+  // CHECK: vector.transfer_write %[[RED_V1]], %[[INIT_ZERO]][%[[C0]]]
+  // CHECK-SAME: : vector<f32>, tensor<1xf32>
+
+  %2 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> (0)>],
+    iterator_types = ["reduction"]}
+    ins(%arg0 : tensor<32xf32>)
+    outs(%1 : tensor<1xf32>) {
+  ^bb0(%a: f32, %b: f32):
+    %3 = arith.addf %a, %b : f32
+    linalg.yield %3 : f32
+  } -> tensor<1xf32>
+
+  return %2 : tensor<1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+
 // -----
 
 // This test checks that vectorization does not occur when an input indexing map
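Like the other cases in this file, the new test is driven by the transform-dialect interpreter; assuming the file's usual RUN line (not shown in this diff), it is exercised with something like:

// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s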
