Skip to content

Commit 856a8b5

Browse files
authored
[mlir][linalg] Add mixed precision folding pattern in vectorize_children_and_apply_patterns TD Op (#148684)
In case of mixed precision inputs, the inputs are generally casted to match output type thereby introduces arith.extFOp/extIOp instructions. Folding such pattern into vector.contract is desirable for HW having mixed precision ISA support. This patch adds folding of mixed precision pattern into vector.contract optionaly which can be enabled using attribute `fold_type_extensions_into_contract`.
1 parent 0720af8 commit 856a8b5

File tree

3 files changed

+461
-325
lines changed

3 files changed

+461
-325
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2348,6 +2348,9 @@ def VectorizeChildrenAndApplyPatternsOp :
23482348
operation that is contained inside the vectorization target.
23492349

23502350
This transformation supports the following attributes:
2351+
- `fold_type_extensions_into_contract`: a `UnitAttr` to enable the folding of
2352+
type extension operations into `vector.contract` to create a mixed precision
2353+
operation.
23512354
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
23522355
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
23532356
loops.
@@ -2368,6 +2371,7 @@ def VectorizeChildrenAndApplyPatternsOp :
23682371
}];
23692372

23702373
let arguments = (ins TransformHandleTypeInterface:$target,
2374+
UnitAttr:$fold_type_extensions_into_contract,
23712375
UnitAttr:$vectorize_padding,
23722376
UnitAttr:$vectorize_nd_extract,
23732377
UnitAttr:$flatten_1d_depthwise_conv,
@@ -2381,6 +2385,7 @@ def VectorizeChildrenAndApplyPatternsOp :
23812385

23822386
let builders = [
23832387
OpBuilder<(ins "Value":$target,
2388+
CArg<"bool", "false">:$foldTypeExtensionsIntoContract,
23842389
CArg<"bool", "false">:$vectorizePadding,
23852390
CArg<"bool", "false">:$vectorizeNDExtract,
23862391
CArg<"bool", "false">:$flatten1DDepthwise)>

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3783,8 +3783,15 @@ LogicalResult TileUsingForallOp::verify() {
37833783

37843784
void transform::VectorizeChildrenAndApplyPatternsOp::build(
37853785
OpBuilder &builder, OperationState &result, Value target,
3786-
bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3786+
bool foldTypeExtensionsIntoContract, bool vectorizePadding,
3787+
bool vectorizeExtract, bool flatten1DDepthwiseConv) {
37873788
result.addOperands(target);
3789+
if (foldTypeExtensionsIntoContract) {
3790+
result.addAttribute(
3791+
VectorizeChildrenAndApplyPatternsOp::
3792+
getFoldTypeExtensionsIntoContractAttrName(result.name),
3793+
builder.getUnitAttr());
3794+
}
37883795
if (vectorizePadding) {
37893796
result.addAttribute(
37903797
VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3875,6 +3882,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
38753882

38763883
patterns.add<CopyVectorizationPattern>(ctx);
38773884

3885+
if (getFoldTypeExtensionsIntoContract())
3886+
vector::populateFoldArithExtensionPatterns(patterns);
3887+
38783888
if (getVectorizePadding()) {
38793889
linalg::populatePadOpVectorizationPatterns(patterns);
38803890
// This creates an alternative path for lowering tensor.pad - by

0 commit comments

Comments
 (0)