Skip to content

Commit b53e46f

Browse files
authored
[mlir][x86vector] Lower vector.contract to FMA or packed type dot-product (#168074)
A `transform` pass to lower `vector.contract` to (a) `vector.fma` for `F32`, (b) `x86vector.avx512.dot` for `BF16`, and (c) `x86vector.avx.dot.i8` for `Int8` packed types. The lowering applies only when the `m`, `batch`, and `k` dims are `one` and the `vnni` dim is `2` for `bf16` or `4` for `int8`. **The lowering pattern**: `batch_reduce.matmul` (input) -> register-tiling(M, N) -> Vectorization (to `vector.contract`) -> `unroll` vector.contract (`unit` dims) -> `hoisting` transformation (move `C` loads/stores outside the batch/k loop) -> apply `licm`, `canonicalization`, and `bufferize`.
1 parent 13a39ea commit b53e46f

File tree

14 files changed

+1647
-0
lines changed

14 files changed

+1647
-0
lines changed

mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
33

44
add_mlir_interface(X86VectorInterfaces)
55
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
6+
7+
add_subdirectory(TransformOps)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Run TableGen over X86VectorTransformOps.td to emit the op declarations
# (.h.inc) and op definitions (.cpp.inc), and expose them as a build target.
set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
10+
#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
13+
#include "mlir/IR/OpImplementation.h"
14+
15+
//===----------------------------------------------------------------------===//
16+
// X86Vector Transform Operations
17+
//===----------------------------------------------------------------------===//
18+
19+
#define GET_OP_CLASSES
20+
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
21+
22+
namespace mlir {

// Forward declaration; only a reference to the registry is needed here.
class DialectRegistry;

namespace x86vector {

/// Adds the X86Vector transform dialect extension to `registry`.
void registerTransformDialectExtension(DialectRegistry &registry);

} // namespace x86vector
} // namespace mlir
30+
31+
#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef X86VECTOR_TRANSFORM_OPS
10+
#define X86VECTOR_TRANSFORM_OPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
15+
include "mlir/IR/OpBase.td"
16+
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
17+
include "mlir/Dialect/Transform/IR/TransformTypes.td"
18+
include "mlir/IR/RegionKindInterface.td"
19+
20+
// Pattern-descriptor transform op: when applied, it populates the rewrite
// patterns that lower F32 vector.contract operations to FMA.
def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
    "apply_patterns.x86vector.vector_contract_to_fma",
    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
  let assemblyFormat = "attr-dict";

  let description = [{
    Collect patterns to lower a F32 type vector.contract operation to a FMA.
  }];
}
29+
30+
// Pattern-descriptor transform op: when applied, it populates the rewrite
// patterns that lower BF16/Int8 vector.contract operations to packed-type
// dot-product operations.
def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
    "apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
  let assemblyFormat = "attr-dict";

  let description = [{
    Collect patterns to lower a BF16/Int8 type vector.contract operation
    to a BF16/Int8 dot-product.
  }];
}
40+
41+
42+
#endif // X86VECTOR_TRANSFORM_OPS
43+

mlir/include/mlir/Dialect/X86Vector/Transforms.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,18 @@ struct MaskHelper {
7979
}
8080
};
8181

82+
//===----------------------------------------------------------------------===//
83+
84+
// A set of patterns for specialized lowering of vector contraction
85+
// operation to vector fused multiply and add (FMA) operation.
86+
void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
87+
88+
// A set of patterns for lowering 32-bit packed vector contraction operations
89+
// to their corresponding packed-type dot-product operations, ultimately
90+
// targeting the relevant x86 LLVM intrinsics (e.g., BF16 and Int8).
91+
void populateVectorContractToPackedTypeDotProductPatterns(
92+
RewritePatternSet &patterns);
93+
8294
//===----------------------------------------------------------------------===//
8395
/// Helpers extracted from:
8496
/// - clang/lib/Headers/avxintrin.h
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3+
add_subdirectory(TransformOps)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Library holding the X86Vector transform-dialect ops. It depends on the
# TableGen-generated op declarations/definitions being built first.
add_mlir_dialect_library(MLIRX86VectorTransformOps
  X86VectorTransformOps.cpp

  DEPENDS
  MLIRX86VectorTransformOpsIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRLLVMCommonConversion
  MLIRLLVMDialect
  MLIRVectorDialect
  MLIRSideEffectInterfaces
  MLIRTransformDialect
  MLIRTransformDialectUtils
  MLIRX86VectorDialect
  MLIRX86VectorTransforms
  )
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===- X86VectorTransformOps.cpp ------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
10+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
11+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
13+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
#include "mlir/Dialect/X86Vector/Transforms.h"
16+
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
17+
18+
#include "mlir/IR/OpImplementation.h"
19+
#include "mlir/IR/RegionKindInterface.h"
20+
21+
using namespace mlir;
22+
using namespace mlir::x86vector;
23+
using namespace mlir::transform;
24+
25+
/// Fills `patterns` with the vector.contract -> FMA lowering patterns provided
/// by the X86Vector transforms library.
void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  mlir::x86vector::populateVectorContractToFMAPatterns(patterns);
}
29+
30+
void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
31+
populatePatterns(RewritePatternSet &patterns) {
32+
x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
33+
}
34+
35+
//===----------------------------------------------------------------------===//
36+
// Transform op registration
37+
//===----------------------------------------------------------------------===//
38+
39+
namespace {
40+
/// Transform dialect extension for X86Vector. On construction it declares the
/// X86Vector and LLVM dialects as generated dialects and registers every
/// transform op listed in the TableGen-generated op list.
class X86VectorTransformDialectExtension
    : public transform::TransformDialectExtension<
          X86VectorTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      X86VectorTransformDialectExtension)

  X86VectorTransformDialectExtension() {
    declareGeneratedDialect<x86vector::X86VectorDialect>();
    declareGeneratedDialect<LLVM::LLVMDialect>();
    // GET_OP_LIST makes the generated .cpp.inc expand to the comma-separated
    // list of op classes, which becomes the template argument list below.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
        >();
  }
};
56+
} // namespace
57+
58+
#define GET_OP_CLASSES
59+
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
60+
61+
/// Adds the X86Vector transform dialect extension to `registry`.
void mlir::x86vector::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<X86VectorTransformDialectExtension>();
}

mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
add_mlir_dialect_library(MLIRX86VectorTransforms
22
AVXTranspose.cpp
33
LegalizeForLLVMExport.cpp
4+
VectorContractToFMA.cpp
5+
VectorContractToPackedTypeDotProduct.cpp
46

57
LINK_LIBS PUBLIC
68
MLIRArithDialect
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
//===- VectorContractToFMA.cpp --------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
10+
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
11+
#include "mlir/Dialect/X86Vector/Transforms.h"
12+
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
13+
14+
#include "mlir/IR/BuiltinAttributes.h"
15+
#include "mlir/IR/Dominance.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
18+
#include "mlir/Pass/Pass.h"
19+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20+
21+
using namespace mlir;
22+
using namespace mlir::vector;
23+
using namespace mlir::x86vector;
24+
25+
namespace {
26+
27+
// Implements outer product contraction as a sequence of broadcast and
28+
// FMA operations.
29+
//
30+
// For example - for F32 type:
31+
// ```
32+
// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
33+
// ```
34+
// to
35+
// ```
36+
// vector.broadcast %lhs to <16xf32>
37+
// vector.fma vector<16xf32>
38+
// ```
39+
struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
40+
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
41+
42+
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
43+
PatternRewriter &rewriter) const override {
44+
45+
if (contractOp.getKind() != vector::CombiningKind::ADD)
46+
return rewriter.notifyMatchFailure(contractOp,
47+
"Expects add combining kind.");
48+
49+
VectorType lhsTy = contractOp.getLhsType();
50+
if (!lhsTy.getElementType().isF32())
51+
return rewriter.notifyMatchFailure(contractOp,
52+
"Only F32 lowering is supported.");
53+
54+
ArrayRef<int64_t> lhsShape = lhsTy.getShape();
55+
llvm::SmallVector<int64_t> nonUnitDimLhs;
56+
llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
57+
[](int64_t dim) { return dim != 1; });
58+
59+
VectorType rhsTy = contractOp.getRhsType();
60+
ArrayRef<int64_t> rhsShape = rhsTy.getShape();
61+
llvm::SmallVector<int64_t> nonUnitDimRhs;
62+
llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
63+
[](int64_t dim) { return dim != 1; });
64+
65+
if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
66+
return rewriter.notifyMatchFailure(
67+
contractOp, "Excepts unit dimensions for either LHS or RHS shape.");
68+
69+
if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
70+
return rewriter.notifyMatchFailure(
71+
contractOp,
72+
"Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
73+
74+
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
75+
if (!accTy)
76+
return rewriter.notifyMatchFailure(contractOp,
77+
"Accmulator is not a vector type");
78+
79+
if (!accTy.getElementType().isF32())
80+
return rewriter.notifyMatchFailure(contractOp,
81+
"Accmulator should be F32 type.");
82+
83+
ArrayRef<int64_t> accShape = accTy.getShape();
84+
llvm::SmallVector<int64_t> nonUnitDimAcc;
85+
llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
86+
[](int64_t dim) { return dim != 1; });
87+
if (nonUnitDimAcc.size() != 1)
88+
return rewriter.notifyMatchFailure(
89+
contractOp, "A or B dimension should be non-unit.");
90+
91+
// Lowers vector.contract into a broadcast+FMA sequence.
92+
auto loc = contractOp.getLoc();
93+
auto castAcc = vector::ShapeCastOp::create(
94+
rewriter, loc,
95+
VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
96+
contractOp.getAcc());
97+
98+
vector::FMAOp fma;
99+
100+
// Broadcast the unit-dimension LHS or RHS to match the vector length of the
101+
// corresponding non-unit dimension on the other operand. For example,
102+
// if LHS has type vector<1x1xf32> and RHS has type vector<1x16xf32>, we
103+
// broadcast the LHS to vector<1x16xf32>. In the opposite case (non-unit
104+
// dimension on the LHS), we broadcast the RHS instead.
105+
if (nonUnitDimRhs.size() > 0) {
106+
auto castLhs = vector::ShapeCastOp::create(
107+
rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
108+
contractOp.getLhs());
109+
auto castRhs = vector::ShapeCastOp::create(
110+
rewriter, loc,
111+
VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
112+
contractOp.getRhs());
113+
auto broadcastLhs = vector::BroadcastOp::create(
114+
rewriter, loc, castRhs.getResult().getType(), castLhs);
115+
fma =
116+
vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
117+
} else {
118+
auto castLhs = vector::ShapeCastOp::create(
119+
rewriter, loc,
120+
VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
121+
contractOp.getLhs());
122+
auto castRhs = vector::ShapeCastOp::create(
123+
rewriter, loc, VectorType::get(1, rhsTy.getElementType()),
124+
contractOp.getRhs());
125+
auto broadcastRhs = vector::BroadcastOp::create(
126+
rewriter, loc, castLhs.getResult().getType(), castRhs);
127+
fma =
128+
vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc);
129+
}
130+
131+
auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
132+
rewriter.replaceOp(contractOp, castFma);
133+
134+
return success();
135+
}
136+
};
137+
138+
} // namespace
139+
140+
/// Registers the VectorContractToFMA rewrite pattern in `patterns`, using the
/// context the pattern set was created with.
void x86vector::populateVectorContractToFMAPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<VectorContractToFMA>(context);
}

0 commit comments

Comments
 (0)