Skip to content

Commit 6ca4fe6

Browse files
committed
[mlir][nfc] Make vectorize_nd_extract optional
Depends on: D157774 Differential Revision: https://reviews.llvm.org/D159360
1 parent 71ca53b commit 6ca4fe6

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2114,7 +2114,7 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
21142114

21152115
let arguments = (ins TransformHandleTypeInterface:$target,
21162116
Variadic<TransformHandleTypeInterface>:$vector_sizes,
2117-
UnitAttr:$vectorize_nd_extract,
2117+
OptionalAttr<UnitAttr>:$vectorize_nd_extract,
21182118
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
21192119
$scalable_sizes,
21202120
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
@@ -2126,7 +2126,8 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
21262126
`vector_sizes` custom<DynamicIndexList>($vector_sizes,
21272127
$static_vector_sizes,
21282128
type($vector_sizes),
2129-
$scalable_sizes)
2129+
$scalable_sizes) |
2130+
`vectorize_nd_extract` $vectorize_nd_extract
21302131
)
21312132
attr-dict
21322133
`:` type($target)

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3231,7 +3231,9 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
32313231

32323232
if (failed(linalg::vectorize(rewriter, target, vectorSizes,
32333233
getScalableSizes(),
3234-
getVectorizeNdExtract()))) {
3234+
getVectorizeNdExtract().has_value()
3235+
? getVectorizeNdExtract().value()
3236+
: false))) {
32353237
return mlir::emitSilenceableFailure(target->getLoc())
32363238
<< "Attempted to vectorize, but failed";
32373239
}

mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguou
2828
transform.sequence failures(propagate) {
2929
^bb1(%arg1: !transform.any_op):
3030
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
31-
transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract } : !transform.any_op
31+
transform.structured.masked_vectorize %0 vector_sizes [1, 4] vectorize_nd_extract : !transform.any_op
3232
}
3333

3434
// -----
@@ -83,7 +83,7 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
8383
transform.sequence failures(propagate) {
8484
^bb1(%arg1: !transform.any_op):
8585
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
86-
transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract } : !transform.any_op
86+
transform.structured.masked_vectorize %0 vector_sizes [1, 4] vectorize_nd_extract : !transform.any_op
8787
}
8888

8989
// -----
@@ -121,7 +121,7 @@ func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tenso
121121
transform.sequence failures(propagate) {
122122
^bb1(%arg1: !transform.any_op):
123123
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
124-
transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract } : !transform.any_op
124+
transform.structured.masked_vectorize %0 vector_sizes [1, 4] vectorize_nd_extract : !transform.any_op
125125
}
126126

127127
// -----
@@ -176,7 +176,7 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_gather(%
176176
transform.sequence failures(propagate) {
177177
^bb1(%arg1: !transform.any_op):
178178
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
179-
transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract } : !transform.any_op
179+
transform.structured.masked_vectorize %0 vector_sizes [1, 4] vectorize_nd_extract : !transform.any_op
180180
}
181181

182182
// -----
@@ -226,7 +226,7 @@ func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf3
226226
transform.sequence failures(propagate) {
227227
^bb1(%arg1: !transform.any_op):
228228
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
229-
transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract } : !transform.any_op
229+
transform.structured.masked_vectorize %0 vector_sizes [3, 3] vectorize_nd_extract : !transform.any_op
230230
}
231231

232232
// -----
@@ -269,5 +269,5 @@ func.func @tensor_extract_dynamic_shape(%arg1: tensor<123x321xf32>, %arg2: tenso
269269
transform.sequence failures(propagate) {
270270
^bb1(%arg1: !transform.any_op):
271271
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
272-
transform.structured.masked_vectorize %0 vector_sizes [1, 3, 8] { vectorize_nd_extract } : !transform.any_op
272+
transform.structured.masked_vectorize %0 vector_sizes [1, 3, 8] vectorize_nd_extract : !transform.any_op
273273
}

0 commit comments

Comments
 (0)