Commit a2acb2f

[mlir][linalg] Fix vectorization of tensor.extract (#118105)
The example below demonstrates a "scalar read followed by a broadcast" pattern for `tensor.extract`:

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @scalar_broadcast(
    %init : tensor<1x1x3xi32>,
    %src : tensor<1x3x2x4xi32>,
    %idx : index) -> tensor<1x1x3xi32> {

  %c0 = arith.constant 0 : index

  %res = linalg.generic {
    indexing_maps = [#map],
    iterator_types = ["parallel", "parallel", "parallel"]}
    outs(%init : tensor<1x1x3xi32>) {
  ^bb0(%out: i32):
    %val = tensor.extract %src[%idx, %idx, %idx, %idx] : tensor<1x3x2x4xi32>
    linalg.yield %val : i32
  } -> tensor<1x1x3xi32>

  return %res : tensor<1x1x3xi32>
}
```

The default masking path within the Linalg vectorizer is not suitable here: it assumes an identity masking map, and an identity map is not a broadcast map. This patch instead handles masking in the `vectorizeTensorExtract` hook, which has the necessary context for proper handling.

Fixes #116197
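For reference, here is a sketch of the masked IR this patch produces for the example above when vectorizing with vector sizes `[1, 1, 4]`, reconstructed from the CHECK lines in the updated test below (the function name and SSA names are illustrative, and constants are spelled out rather than listed in the order the vectorizer emits them):

```mlir
func.func @scalar_broadcast_vectorized(
    %init : tensor<1x1x3xi32>,
    %src : tensor<1x3x2x4xi32>,
    %idx : index) -> tensor<1x1x3xi32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %pad = arith.constant 0 : i32

  // The read touches a single scalar (every result dimension is a
  // broadcast), so it gets a one-element, all-true mask instead of the
  // identity-shaped mask the generic masking path would have built.
  %mask_read = vector.constant_mask [1] : vector<1xi1>
  %read = vector.mask %mask_read {
    vector.transfer_read %src[%idx, %idx, %idx, %idx], %pad
      {in_bounds = [true, true, true],
       permutation_map = affine_map<(d0, d1, d2, d3) -> (0, 0, 0)>}
      : tensor<1x3x2x4xi32>, vector<1x1x4xi32>
  } : vector<1xi1> -> vector<1x1x4xi32>

  // The write back keeps the usual mask derived from the iteration space
  // (1x1x3) relative to the chosen vector sizes (1x1x4).
  %mask_res = vector.create_mask %c1, %c1, %c3 : vector<1x1x4xi1>
  %res = vector.mask %mask_res {
    vector.transfer_write %read, %init[%c0, %c0, %c0]
      {in_bounds = [true, true, true]}
      : vector<1x1x4xi32>, tensor<1x1x3xi32>
  } : vector<1x1x4xi1> -> tensor<1x1x3xi32>

  return %res : tensor<1x1x3xi32>
}
```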
1 parent c7ef0ac commit a2acb2f

File tree: 3 files changed (+78, -14 lines)


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 11 additions & 1 deletion
```diff
@@ -1165,8 +1165,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         loc, resultType, extractOp.getTensor(), transferReadIdxs,
         permutationMap, inBounds);
 
+    // Mask this broadcasting xfer_read here rather than relying on the generic
+    // path (the generic path assumes identity masking map, which wouldn't be
+    // valid here).
+    SmallVector<int64_t> readMaskShape = {1};
+    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+    auto allTrue = rewriter.create<vector::ConstantMaskOp>(
+        loc, readMaskType, vector::ConstantMaskKind::AllTrue);
+    auto *maskedReadOp =
+        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
+
     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
-    return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
   }
 
   // 2b. Handle contiguous access.
```

mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir

Lines changed: 52 additions & 0 deletions
```diff
@@ -425,3 +425,55 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @scalar_broadcast(%init : tensor<1x1x3xi32>, %src : tensor<1x3x2x4xi32>, %idx : index) -> tensor<1x1x3xi32> {
+
+  %c0 = arith.constant 0 : index
+
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    outs(%init : tensor<1x1x3xi32>) {
+  ^bb0(%out: i32):
+    %val = tensor.extract %src[%idx, %idx, %idx, %idx] : tensor<1x3x2x4xi32>
+    linalg.yield %val : i32
+  } -> tensor<1x1x3xi32>
+
+  return %res : tensor<1x1x3xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (0, 0, 0)>
+// CHECK-LABEL: func.func @scalar_broadcast(
+// CHECK-SAME:    %[[INIT:.*]]: tensor<1x1x3xi32>,
+// CHECK-SAME:    %[[SRC:.*]]: tensor<1x3x2x4xi32>,
+// CHECK-SAME:    %[[IDX:.*]]: index) -> tensor<1x1x3xi32> {
+
+/// Compute the mask for saving the final result
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[MASK_RES:.*]] = vector.create_mask %[[C1]], %[[C1_2]], %[[C3]] : vector<1x1x4xi1>
+
+/// Read and broadcast the scalar
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[MASK_READ:.*]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {
+// CHECK-SAME:    vector.transfer_read %[[SRC]]{{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PAD]]
+// CHECK-SAME:    {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<1x3x2x4xi32>, vector<1x1x4xi32>
+// CHECK-SAME:  } : vector<1xi1> -> vector<1x1x4xi32>
+
+/// Save the result in the output tensor
+// CHECK: vector.mask %[[MASK_RES]] {
+// CHECK-SAME:    vector.transfer_write %[[READ]], %[[INIT]]{{.*}} {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x3xi32>
+// CHECK-SAME:  } : vector<1x1x4xi1> -> tensor<1x1x3xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %module : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 1, 4] {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}
```
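The new test reuses the file's existing transform-interpreter setup, so masked vectorization is driven by the `transform.structured.vectorize ... vector_sizes [1, 1, 4]` sequence above. Assuming the file's usual lit configuration (the RUN line lives at the top of the file and is not part of this diff), it can be exercised standalone with something like `mlir-opt -transform-interpreter -split-input-file vectorize-tensor-extract-masked.mlir | FileCheck vectorize-tensor-extract-masked.mlir`.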

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 15 additions & 13 deletions
```diff
@@ -66,7 +66,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+func.func @vectorize_nd_tensor_extract_scalar_broadcast(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %c0 = arith.constant 1 : index
   %c1 = arith.constant 2 : index
   %2 = linalg.generic {
@@ -80,17 +80,17 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
   return %2 : tensor<1x1x3xf32>
 }
 
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (0, 0, 0)>
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx(
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_scalar_broadcast(
 // CHECK-SAME: %[[ARG_0:.*]]: tensor<3x3xf32>,
 // CHECK-SAME: %[[ARG_1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C0_f32_2:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]][%[[C1]], %[[C2]]], %[[C0_f32]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32>
-// CHECK: %[[C0_4:.*]] = arith.constant 0 : index
-// CHECK: vector.transfer_write %[[READ]], %[[ARG_1]][%[[C0_4]], %[[C0_4]], %[[C0_4]]] : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK:.*]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C1]], %[[C2]]], {{.*}} {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32> } : vector<1xi1> -> vector<1x1x3xf32>
+// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK: vector.transfer_write %[[READ]], %[[ARG_1]]{{\[}}%[[C0_2]], %[[C0_2]], %[[C0_2]]] : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -823,7 +823,7 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
   return %out:tensor<1x1x4xi32>
 }
 
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
 // CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
 // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
@@ -844,12 +844,14 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
 // CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
 // CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
 // CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_20]][0] : index from vector<4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_19]][0] : index from vector<4xindex>
+// CHECK: %[[VAL_21:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_22:.*]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK: %[[VAL_23:.*]] = vector.mask %[[VAL_22]] { vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_20]], %[[VAL_2]]], %[[VAL_21]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<15x1xi32>, vector<1x1x4xi32> } : vector<1xi1> -> vector<1x1x4xi32>
 // CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+// CHECK: return %[[VAL_25]] : tensor<1x1x4xi32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
```