Skip to content

Commit e9bce17

Browse files
authored
Support lowering of vector.contract to amx for brgemm (#1017)
FP32 brgemm can be lowered using FMAs, but this cannot be used for BF16 inputs. Intel AMX has a TMUL functional unit which provides 16x16 tile registers for the bf16 data type along with corresponding load, store, and multiply instructions. This pass lowers the tiled brgemm from the vector dialect to the AMX dialect, which subsequently gets lowered to AMX instructions.
1 parent ae372a8 commit e9bce17

File tree

10 files changed

+1008
-3
lines changed

10 files changed

+1008
-3
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
[
2+
{
3+
"gemm_bf16_dp2_mlir": {
4+
"bf16_dp2_3x1024_omp_2_mlir": {
5+
"type": "IR-GEN",
6+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
7+
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
8+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
9+
"extensions": [ "(avx2)" ]
10+
},
11+
"bf16_dp2_3x1024_omp_4_mlir": {
12+
"type": "IR-GEN",
13+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
14+
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
15+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
16+
"extensions": [ "(avx2)" ]
17+
},
18+
"bf16_dp2_3x1024_omp_8_mlir": {
19+
"type": "IR-GEN",
20+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
21+
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
22+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
23+
"extensions": [ "(avx2)" ]
24+
},
25+
"bf16_dp2_3x1024_omp_16_mlir": {
26+
"type": "IR-GEN",
27+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
28+
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
29+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
30+
"extensions": [ "(avx2)" ]
31+
}
32+
}},
33+
{
34+
"gemm_bf16_dp2_mlir_vector_amx": {
35+
"bf16_dp2_3x1024_omp_2_mlir": {
36+
"type": "IR-GEN",
37+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
38+
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
39+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=32,32,32'" ],
40+
"extensions": ["(amx_bf16)"]
41+
},
42+
"bf16_dp2_3x1024_omp_4_mlir": {
43+
"type": "IR-GEN",
44+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
45+
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
46+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=32,32,32'" ],
47+
"extensions": ["(amx_bf16)"]
48+
},
49+
"bf16_dp2_3x1024_omp_8_mlir": {
50+
"type": "IR-GEN",
51+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
52+
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
53+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=32,32,32'" ],
54+
"extensions": ["(amx_bf16)"]
55+
},
56+
"bf16_dp2_3x1024_omp_16_mlir": {
57+
"type": "IR-GEN",
58+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
59+
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
60+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=32,32,32'" ],
61+
"extensions": ["(amx_bf16)"]
62+
}
63+
}},
64+
{
65+
"mlp_bf16_dp2_mlir": {
66+
"bf16_dp2_3x1024_omp_2_mlir": {
67+
"type": "IR-GEN",
68+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
69+
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
70+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
71+
"extensions": [ "(avx2)" ]
72+
},
73+
"bf16_dp2_3x1024_omp_4_mlir": {
74+
"type": "IR-GEN",
75+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
76+
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
77+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
78+
"extensions": [ "(avx2)" ]
79+
},
80+
"bf16_dp2_3x1024_omp_8_mlir": {
81+
"type": "IR-GEN",
82+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
83+
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
84+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
85+
"extensions": [ "(avx2)" ]
86+
},
87+
"bf16_dp2_3x1024_omp_16_mlir": {
88+
"type": "IR-GEN",
89+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
90+
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
91+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
92+
"extensions": [ "(avx2)" ]
93+
}
94+
}},
95+
{
96+
"mlp_bf16_dp2_mlir_vector_amx": {
97+
"bf16_dp2_3x1024_omp_2_mlir": {
98+
"type": "IR-GEN",
99+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
100+
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
101+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=32,32,32'" ],
102+
"extensions": [ "(amx_bf16)" ]
103+
},
104+
"bf16_dp2_3x1024_omp_4_mlir": {
105+
"type": "IR-GEN",
106+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
107+
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
108+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=32,32,32'" ],
109+
"extensions": [ "(amx_bf16)" ]
110+
},
111+
"bf16_dp2_3x1024_omp_8_mlir": {
112+
"type": "IR-GEN",
113+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
114+
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
115+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=32,32,32'" ],
116+
"extensions": [ "(amx_bf16)" ]
117+
},
118+
"bf16_dp2_3x1024_omp_16_mlir": {
119+
"type": "IR-GEN",
120+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=bf16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
121+
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
122+
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=32,32,32'" ],
123+
"extensions": [ "(amx_bf16)" ]
124+
}
125+
}}
126+
]

include/TPP/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ namespace xegpu {
9696
class XeGPUDialect;
9797
} // namespace xegpu
9898

99+
namespace amx {
100+
class AMXDialect;
101+
} // namespace amx
102+
99103
namespace x86vector {
100104
class X86VectorDialect;
101105
} // namespace x86vector

include/TPP/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,18 @@ def VectorContractToFMA : Pass<
8888
"arith::ArithDialect"];
8989
}
9090

91+
def VectorContractToAMX : Pass<
92+
"vector-contract-to-amx"> {
93+
let summary = "Perform vector amx lowering of vector contraction ops";
94+
let dependentDialects = ["memref::MemRefDialect",
95+
"scf::SCFDialect",
96+
"tensor::TensorDialect",
97+
"vector::VectorDialect",
98+
"arith::ArithDialect",
99+
"amx::AMXDialect",
100+
"x86vector::X86VectorDialect"];
101+
}
102+
91103

92104
def BrgemmLinalgTiling : Pass<"tile-brgemm-linalg"> {
93105
let summary = "Tile bregmm matmul and reduction dimension.";

lib/TPP/DefaultPipeline.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,6 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
196196
pm.addPass(createConvertVectorToLLVMPass(options));
197197
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
198198
pm.addPass(createSCFToControlFlowPass());
199-
if (defParallel)
200-
pm.addPass(createConvertOpenMPToLLVMPass());
201199

202200
pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
203201
pm.addPass(createGpuToLLVMConversionPass());
@@ -214,6 +212,8 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
214212

215213
pm.addPass(createArithToLLVMConversionPass());
216214
pm.addPass(createConvertControlFlowToLLVMPass());
215+
if (defParallel)
216+
pm.addPass(createConvertOpenMPToLLVMPass());
217217
pm.addPass(createUBToLLVMConversionPass());
218218
pm.addPass(createCanonicalizerPass());
219219
pm.addPass(createCSEPass());

lib/TPP/PassBundles/VectorToKernel.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "TPP/Transforms/Utils/VNNIUtils.h"
910
#include "mlir/Dialect/Func/IR/FuncOps.h"
1011
#include "mlir/Dialect/SCF/IR/SCF.h"
1112
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1213
#include "mlir/IR/BuiltinOps.h"
1314
#include "mlir/Pass/Pass.h"
1415
#include "mlir/Pass/PassManager.h"
15-
#include "llvm/Support/Debug.h"
1616
#include "mlir/Transforms/Passes.h"
17+
#include "llvm/Support/Debug.h"
1718

1819
#include "TPP/PassBundles.h"
1920
#include "TPP/PassUtils.h"
@@ -49,8 +50,14 @@ struct VectorToKernel : public tpp::impl::VectorToKernelBase<VectorToKernel>,
4950

5051
private:
5152
void constructPipeline() override {
53+
// TODO: Pass ordering based on target architecture starting from AMX ->
54+
// avx512 -> avx2 to subset needs to be improved by moving out some logic of
55+
// Bf16DotProduct related to iterarg creation and let hoistvectorTransfer
56+
// pass address it.
5257
pm.addNestedPass<func::FuncOp>(createBF16DotProduct());
5358
pm.addNestedPass<func::FuncOp>(createHoistVectorTransfers());
59+
if (vnni::utils::hasAMX())
60+
pm.addNestedPass<func::FuncOp>(createVectorContractToAMX());
5461
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
5562
pm.addNestedPass<func::FuncOp>(createVectorContractToFMA());
5663
}

lib/TPP/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ add_mlir_library(TPPTransforms
3131
HoistVectorTransfers.cpp
3232
VectorContractToFMA.cpp
3333
VectorContractToBF16DotProduct.cpp
34+
VectorContractToAMX.cpp
3435

3536
ADDITIONAL_HEADER_DIRS
3637
${PROJECT_SOURCE_DIR}/include/TPP

0 commit comments

Comments
 (0)