Commit 75b113b

Support optimal vector lowering for avx2 target feature. (#1048)
Introduce optimal vector lowering of the matmul microkernel for the avx2 target feature in the VectorContractToFMA pass, using register blocking to make full use of the available registers without any spill.
1 parent d8755d3 commit 75b113b
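As a mental model of the microkernel this lowering targets (an illustrative intrinsics sketch, not the pass output; buffer names, layouts, and strides are assumptions), a 3x32 register-blocked AVX2 FP32 kernel keeps 3 * (32 / 8) = 12 YMM accumulators live, plus one B row vector and an A broadcast, which fits in the 16 YMM registers so the accumulators never spill:

#include <immintrin.h> // illustrative only; compile with -mavx2 -mfma

void microkernel_3x32_fp32(const float *A, const float *B, float *C,
                           int K, int lda, int ldb, int ldc) {
  __m256 acc[3][4];
  // Load the 3x32 accumulator tile of C: 12 YMM registers.
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 4; ++j)
      acc[i][j] = _mm256_loadu_ps(&C[i * ldc + j * 8]);
  for (int k = 0; k < K; ++k) {
    for (int j = 0; j < 4; ++j) {
      __m256 b = _mm256_loadu_ps(&B[k * ldb + j * 8]);   // one row vector of B
      for (int i = 0; i < 3; ++i) {
        __m256 a = _mm256_broadcast_ss(&A[i * lda + k]); // broadcast A(i, k)
        acc[i][j] = _mm256_fmadd_ps(a, b, acc[i][j]);    // 12 FMAs per k step
      }
    }
  }
  // Store the accumulators back to C.
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 4; ++j)
      _mm256_storeu_ps(&C[i * ldc + j * 8], acc[i][j]);
}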

File tree

12 files changed (+674, -85 lines)

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+[
+  {
+    "gemm_fp32_24x64_avx2": {
+      "fp32_3x1024_omp_2_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
+        "extensions": [ "(avx2)" ]
+      },
+      "fp32_3x1024_omp_4_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
+        "extensions": [ "(avx2)" ]
+      },
+      "fp32_3x1024_omp_8_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
+        "extensions": [ "(avx2)" ]
+      },
+      "fp32_3x1024_omp_16_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
+        "extensions": [ "(avx2)" ]
+      }
+  }},
+  {
+    "gemm_fp32_mlir_vector_kernel_24x64_avx2": {
+      "fp32_3x1024_omp_2_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --target-feature=avx2 --vector-to-kernels --registerBlocking=3,32,1 '" ],
+        "extensions": [ "avx2" ]
+      },
+      "fp32_3x1024_omp_4_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --target-feature=avx2 --vector-to-kernels --registerBlocking=3,32,1 '" ],
+        "extensions": [ "avx2" ]
+      },
+      "fp32_3x1024_omp_8_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --target-feature=avx2 --vector-to-kernels --registerBlocking=3,32,1 '" ],
+        "extensions": [ "avx2" ]
+      },
+      "fp32_3x1024_omp_16_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --target-feature=avx2 --vector-to-kernels --registerBlocking=3,32,1 '" ],
+        "extensions": [ "avx2" ]
+      }
+  }},
+  {
+    "mlp_fp32_24x64_avx2": {
+      "fp32_3x1024_omp_2_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
+        "extensions": [ "(avx2)" ]
+      },
+      "fp32_3x1024_omp_4_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
+        "extensions": [ "(avx2)" ]
+      },
+      "fp32_3x1024_omp_8_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
+        "extensions": [ "(avx2)" ]
+      },
+      "fp32_3x1024_omp_16_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
+        "extensions": [ "(avx2)" ]
+      }
+  }},
+  {
+    "mlp_fp32_mlir_vector_kernel_24x64_avx2": {
+      "fp32_3x1024_omp_2_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=3,32,1 '" ],
+        "extensions": [ "avx2" ]
+      },
+      "fp32_3x1024_omp_4_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=3,32,1 '" ],
+        "extensions": [ "avx2" ]
+      },
+      "fp32_3x1024_omp_8_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=3,32,1 '" ],
+        "extensions": [ "avx2" ]
+      },
+      "fp32_3x1024_omp_16_mlir": {
+        "type": "IR-GEN",
+        "benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=96 --layers=1536,1536 --tiles=24,64,4" ],
+        "environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
+        "flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=3,32,1 '" ],
+        "extensions": [ "avx2" ]
+      }
+  }}
+]

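The vector-kernel entries above pair --tiles=24,64,4 with --registerBlocking=3,32,1. Below is a quick arithmetic check of how those numbers fit an AVX2 register file (illustrative only, and assuming the usual M,N,K ordering of both flags, which the config itself does not spell out):

#include <cstdio>

int main() {
  const int tileM = 24, tileN = 64; // from --tiles=24,64,4
  const int blkM = 3, blkN = 32;    // from --registerBlocking=3,32,1
  const int vecLen = 8;             // 256-bit AVX2 register / 32-bit float
  // Each 24x64 tile is covered by an 8x2 grid of 3x32 microkernel tiles,
  // and each microkernel tile needs 3 * (32 / 8) = 12 YMM accumulators.
  std::printf("microkernel grid per tile: %d x %d\n", tileM / blkM,
              tileN / blkN);
  std::printf("YMM accumulators per microkernel: %d\n",
              blkM * (blkN / vecLen));
  return 0;
}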
include/TPP/Transforms/Utils/VNNIUtils.h

Lines changed: 5 additions & 0 deletions
@@ -34,6 +34,11 @@ enum class VnniOperandRank {
   BRGEMM_INS = 4,
   BRGEMM_OUTS = 3
 };
+// Returns True if the current architecture supports AVX2 instructions.
+bool hasAVX2();
+
+// Returns True if the current architecture supports AVX512 instructions.
+bool hasAVX512();
 
 // Returns True if the current architecture supports AMX instructions.
 bool hasAMX();

lib/TPP/Transforms/Utils/VNNIUtils.cpp

Lines changed: 13 additions & 0 deletions
@@ -9,6 +9,7 @@
 #include "TPP/Transforms/Utils/VNNIUtils.h"
 #include "TPP/Transforms/Utils/DLTIUtils.h"
 
+#include "libxsmm_cpuid.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -28,6 +29,18 @@ bool hasAMX() {
          (libxsmm_get_target_archid() < LIBXSMM_X86_ALLFEAT);
 }
 
+// Returns True if the current architecture supports AVX2 instructions.
+bool hasAVX2() {
+  return (libxsmm_get_target_archid() >= LIBXSMM_X86_AVX2) &&
+         (libxsmm_get_target_archid() < LIBXSMM_X86_ALLFEAT);
+}
+
+// Returns True if the current architecture supports AVX512 instructions.
+bool hasAVX512() {
+  return (libxsmm_get_target_archid() >= LIBXSMM_X86_AVX512_SKX) &&
+         (libxsmm_get_target_archid() < LIBXSMM_X86_ALLFEAT);
+}
+
 unsigned getVnniBlockingFactor(Type type, Operation *op) {
   unsigned blockingFactor = 0;

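A minimal usage sketch of the new helpers (not part of the commit; the namespace qualification is inferred from the call sites in VectorContractToFMA.cpp below): pick the f32 vector length from the detected CPU when no explicit target feature is given.

#include "TPP/Transforms/Utils/VNNIUtils.h"

static unsigned detectFP32VectorLength() {
  // 512-bit ZMM gives 16 f32 lanes, 256-bit YMM gives 8; 0 means unsupported.
  unsigned regBits = mlir::vnni::utils::hasAVX512() ? 512
                     : mlir::vnni::utils::hasAVX2() ? 256
                                                    : 0;
  return regBits / 32;
}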
lib/TPP/Transforms/VectorContractToFMA.cpp

Lines changed: 102 additions & 28 deletions
@@ -11,12 +11,15 @@
 
 #include "TPP/Passes.h"
 #include "TPP/Transforms/Transforms.h"
+#include "TPP/Transforms/Utils/VNNIUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "vector-contract-to-fma"
 
@@ -31,6 +34,23 @@ using namespace mlir;
 using namespace mlir::tpp;
 
 namespace {
+
+/// Returns the target vector length based on target features avx2/avx512 for
+/// FP32 data type.
+static unsigned getTargetVectorLengthForFP32(llvm::StringRef targetFeatureStr) {
+  unsigned vecElemTypeSizeInBits = 32;
+  unsigned vecRegSizeInBits = StringSwitch<unsigned>(targetFeatureStr)
+                                  .Case("avx2", 256)
+                                  .Case("avx512", 512)
+                                  .Default(0);
+  if (vecRegSizeInBits > 0)
+    return vecRegSizeInBits / vecElemTypeSizeInBits;
+
+  vecRegSizeInBits = vnni::utils::hasAVX512() ? 512
+                     : vnni::utils::hasAVX2() ? 256
+                                              : 0;
+  return vecRegSizeInBits / vecElemTypeSizeInBits;
+}
 /// Returns true if the \p map is transposed.
 static bool isTransposed(AffineMap map) {
   auto results = map.getResults();
@@ -93,8 +113,12 @@ struct VectorContractToFMA
 struct VectorContractToFMAPattern
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-  VectorContractToFMAPattern(MLIRContext *context, TransformationContext &ctx)
-      : OpRewritePattern<vector::ContractionOp>(context), ctx(ctx) {}
+
+  VectorContractToFMAPattern(MLIRContext *context,
+                             VectorContractToFMAOptions options,
+                             TransformationContext &ctx)
+      : OpRewritePattern<vector::ContractionOp>(context), options(options),
+        ctx(ctx) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
@@ -135,8 +159,6 @@ struct VectorContractToFMAPattern
         iteratorTypes[outerDimIndex + 2] != vector::IteratorType::reduction)
       return rewriter.notifyMatchFailure(op, "Not a gemm");
 
-    SmallVector<Value, 4> results;
-
     auto lhs = op.getLhs();
     auto rhs = op.getRhs();
     auto acc = op.getAcc();
@@ -246,6 +268,12 @@ struct VectorContractToFMAPattern
     if (K != 1)
       return failure();
 
+    unsigned vecLen = getTargetVectorLengthForFP32(options.targetFeature);
+    if (vecLen == 0)
+      return failure();
+
+    SmallVector<Value, 12> results;
+    SmallVector<Value, 12> argResults;
     auto accSubview = accDefiningOp.getBase();
     Location loc = op.getLoc();
 
@@ -273,12 +301,15 @@ struct VectorContractToFMAPattern
       subview_2_splits.push_back(split);
     }
 
-    // Intialize each accumulator with a vector of size N
+    // Initialize each accumulator with a vector of size vecLen.
     SmallVector<Value, 4> initAccs;
     for (auto subview : subview_2_splits) {
-      auto acc = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get({N}, elementType), subview, ValueRange{c0, c0});
-      initAccs.push_back(acc);
+      for (unsigned j = 0; j < N; j += vecLen) {
+        auto acc = rewriter.create<vector::LoadOp>(
+            loc, VectorType::get({vecLen}, elementType), subview,
+            ValueRange{c0, rewriter.create<arith::ConstantIndexOp>(loc, j)});
+        initAccs.push_back(acc);
+      }
     }
 
     // Create new outer loop with M different accumulators.
@@ -314,7 +345,7 @@ struct VectorContractToFMAPattern
                   innerBuilder.create<arith::ConstantIndexOp>(loc, i),
                   c0});
           auto bcast = innerBuilder.create<vector::BroadcastOp>(
-              loc, VectorType::get({N}, elem.getType()), elem);
+              loc, VectorType::get({vecLen}, elem.getType()), elem);
           broadcasts.push_back(bcast);
         }
 
@@ -325,18 +356,58 @@ struct VectorContractToFMAPattern
           rhsMapping.map(
               rhsDefiningOp.getBase().getDefiningOp()->getOperand(2),
              innerIv);
+
+          // Create Mx(N/vecLen) different FMAs using broadcasts and
+          // current accumulator values.
           auto rhsClone = innerBuilder.clone(
              *rhsDefiningOp.getBase().getDefiningOp(), rhsMapping);
-          auto rowVec = innerBuilder.create<vector::LoadOp>(
-              loc, VectorType::get({N}, elementType),
-              rhsClone->getResult(0), ValueRange{c0, c0, c0});
-
-          // Create M different FMAs using broadcasts and current
-          // accumulator values.
-          for (int i = 0; i < M; i++) {
-            auto fma = innerBuilder.create<vector::FMAOp>(
-                loc, broadcasts[i], rowVec, innerIterArgs[i]);
-            results.push_back(fma);
+          if (vecLen == 8) {
+            for (unsigned j = 0; j < N; j += vecLen) {
+              auto rowVec = innerBuilder.create<vector::LoadOp>(
+                  loc, VectorType::get({vecLen}, elementType),
+                  rhsClone->getResult(0),
+                  ValueRange{c0, c0,
+                             innerBuilder.create<arith::ConstantIndexOp>(
+                                 loc, j)});
+              unsigned iterArgAccessStride = N / vecLen;
+              unsigned offset = j / vecLen;
+              for (int i = 0; i < M; i++) {
+                auto fma = innerBuilder.create<vector::FMAOp>(
+                    loc, broadcasts[i], rowVec,
+                    innerIterArgs[offset + iterArgAccessStride * i]);
+                argResults.push_back(fma);
+              }
+            }
+
+            // Perform strided circular copy of elements from argResults
+            // to results.
+            unsigned stride = (N / vecLen);
+            unsigned totalElements = argResults.size();
+            results.resize(totalElements);
+            for (unsigned i = 0; i < totalElements; ++i) {
+              unsigned circularIndex =
+                  (i % stride) * (stride - 1) + (i / stride);
+              results[i] = argResults[circularIndex];
+            }
+
+          } else {
+            for (int i = 0; i < M; i++) {
+              unsigned iterArgAccessStride = (i) * ((N / vecLen));
+              for (unsigned j = 0; j < N; j += vecLen) {
+                auto rowVec = innerBuilder.create<vector::LoadOp>(
+                    loc, VectorType::get({vecLen}, elementType),
+                    rhsClone->getResult(0),
+                    ValueRange{
+                        c0, c0,
+                        innerBuilder.create<arith::ConstantIndexOp>(loc,
+                                                                    j)});
+                unsigned offset = (j / vecLen);
+                auto fma = innerBuilder.create<vector::FMAOp>(
+                    loc, broadcasts[i], rowVec,
+                    innerIterArgs[offset + iterArgAccessStride]);
+                results.push_back(fma);
+              }
+            }
           }
 
           // Yield all M results
@@ -358,9 +429,14 @@ struct VectorContractToFMAPattern
     // Store final results back to original locations.
     if (writeOp) {
       for (int i = 0; i < M; i++) {
-        rewriter.create<vector::StoreOp>(loc, newOuterForOp.getResult(i),
-                                         subview_2_splits[i],
-                                         ValueRange{c0, c0});
+        unsigned iterArgAccessStride = i * (N / vecLen);
+        for (unsigned j = 0; j < N; j += vecLen) {
+          unsigned offset = j / vecLen;
+          rewriter.create<vector::StoreOp>(
+              loc, newOuterForOp.getResult(offset + iterArgAccessStride),
+              subview_2_splits[i],
+              ValueRange{c0, rewriter.create<arith::ConstantIndexOp>(loc, j)});
+        }
       }
     }
 
@@ -372,15 +448,17 @@ struct VectorContractToFMAPattern
   }
 
 private:
+  VectorContractToFMAOptions options;
   TransformationContext &ctx;
 };
 
 void VectorContractToFMA::runOnOperation() {
+  VectorContractToFMAOptions options;
   auto funcOp = getOperation();
   MLIRContext *context = &getContext();
-
+  options.targetFeature = targetFeature;
   RewritePatternSet patterns(context);
-  patterns.add<VectorContractToFMAPattern>(context, ctx);
+  patterns.add<VectorContractToFMAPattern>(context, options, ctx);
 
   if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
     signalPassFailure();
@@ -389,7 +467,3 @@ void VectorContractToFMA::runOnOperation() {
 
 } // namespace tpp
 } // namespace mlir
-
-std::unique_ptr<Pass> createVectorContractToFMA() {
-  return std::make_unique<VectorContractToFMA>();
-}

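The "strided circular copy" in the avx2 (vecLen == 8) branch above reorders the FMA values from the j-outer/i-inner order in which they are created into the i-major order that the loop iteration arguments and the final stores expect. Below is a standalone check of that index arithmetic for the blocking exercised by the benchmarks, M = 3, N = 32, vecLen = 8, hence stride = 4 (illustrative only; note that stride - 1 equals M for this particular blocking, which is what makes the remap a clean transpose here):

#include <cstdio>

int main() {
  const unsigned M = 3, stride = 4;          // stride = N / vecLen
  const unsigned totalElements = M * stride; // 12 accumulators
  for (unsigned i = 0; i < totalElements; ++i) {
    unsigned circularIndex = (i % stride) * (stride - 1) + (i / stride);
    // argResults was filled as index = j * M + i, so decode the source pair.
    unsigned jSrc = circularIndex / M, iSrc = circularIndex % M;
    std::printf("results[%2u] <- argResults[%2u]  (i = %u, j = %u)\n", i,
                circularIndex, iSrc, jSrc);
  }
  return 0;
}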