Skip to content

Commit 1d06346

Browse files
authored
[TPP-Pipeline] New flag option to disable vnni packing for packed types. (#1077)
This `patch` adds new flag option to disable vnni layout for packed types.
1 parent 075fdaa commit 1d06346

File tree

4 files changed

+17
-3
lines changed

4 files changed

+17
-3
lines changed

include/TPP/PassBundles.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> {
5252
Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose",
5353
"bool", /*default=*/"false",
5454
"Lower non-constant packs and unpacks reverting any dim permutations.">,
55+
Option<"disableVnniPacking", "disable-vnni-packing",
56+
"bool", /*default=*/"false",
57+
"Disables VNNI packing for packed types.">,
5558
ListOption<"registerBlocking", "registerBlocking",
5659
"unsigned", "Register blocking tile sizes for brgemm operation.">,
5760

@@ -71,7 +74,10 @@ def TppMapping : Pass<"tpp-mapping", "ModuleOp"> {
7174
let options= [
7275
Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose",
7376
"bool", /*default=*/"false",
74-
"Lower non-constant packs and unpacks reverting any dim permutations.">
77+
"Lower non-constant packs and unpacks reverting any dim permutations.">,
78+
Option<"disableVnniPacking", "disable-vnni-packing",
79+
"bool", /*default=*/"false",
80+
"Disables VNNI packing for packed types.">
7581
];
7682
}
7783

lib/TPP/DefaultPipeline.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ llvm::cl::opt<bool> lowerPackUnpackWithoutTranspose(
6767
llvm::cl::desc("Lower packs and unpacks reverting any dim permutations"),
6868
llvm::cl::init(false));
6969

70+
llvm::cl::opt<bool> disableVnniPacking("disable-vnni-packing",
71+
llvm::cl::desc("Disables VNNI packing for packed types"),
72+
llvm::cl::init(false));
7073

7174
llvm::cl::list<unsigned>
7275
registerBlocking("registerBlocking", llvm::cl::desc("Register blocking tile sizes for brgemm operation"),
@@ -156,6 +159,7 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
156159
tppDefaultOptions.linalgToVector = linalgToVector;
157160
tppDefaultOptions.vectorToXSMM = vectorToXSMM;
158161
tppDefaultOptions.lowerPackUnpackWithoutTranspose = lowerPackUnpackWithoutTranspose;
162+
tppDefaultOptions.disableVnniPacking = disableVnniPacking;
159163
tppDefaultOptions.registerBlocking =
160164
SmallVector<unsigned>{registerBlocking.begin(), registerBlocking.end()};
161165
tppDefaultOptions.vectorToKernel = vectorToKernel;

lib/TPP/DefaultTppPasses.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ struct DefaultTppPasses
117117
pm.addPass(createRewriteBatchMatmulToMatmul());
118118

119119
// Applies a set of passes at the linalg level to fuse and pack.
120-
TppMappingOptions tppMappingOptions{lowerPackUnpackWithoutTranspose};
120+
TppMappingOptions tppMappingOptions{lowerPackUnpackWithoutTranspose,
121+
disableVnniPacking};
121122
pm.addPass(createTppMapping(tppMappingOptions));
122123

123124
// Generalize linalg.pack and linalg.unpack.

lib/TPP/PassBundles/TppMapping.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ struct TppMapping : public tpp::impl::TppMappingBase<TppMapping>,
6262
pm.addPass(createPackConv2DNchwFchw());
6363
pm.addPass(createRewriteConvToMatmulOrBrgemm());
6464
pm.addPass(createPackMatmul());
65-
pm.addPass(createPackVNNI());
65+
66+
if (!disableVnniPacking) {
67+
pm.addPass(createPackVNNI());
68+
}
6669

6770
if (lowerPackUnpackWithoutTranspose) {
6871
pm.addPass(createLowerPacksAndUnpacksWithoutTranspose());

0 commit comments

Comments
 (0)