Skip to content

Commit 7df9fa7

Browse files
author
Haixin Huang
authored
[Transform] Introduce passes for transforming from linalg to microkernel dialect (#266)
Contains passes that lower linalg.batch_reduce_matmul and linalgx.batch_reduce_matmul_vnni into the microkernel dialect, together with related tests. The passes are: (1) ConvertLinalgToMicrokernel: converts a linalg op into a frontend microkernel op; (2) ExpandMicrokernel: expands a frontend microkernel op into detailed microkernel ops; (3) ConvertMicrokernelToDnnlFunc: converts microkernel ops into dnnl runtime invocations.
1 parent 048c93c commit 7df9fa7

File tree

25 files changed

+2120
-76
lines changed

25 files changed

+2120
-76
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ if(GC_ENABLE_LEGACY)
5555
add_subdirectory(legacy/core)
5656
endif()
5757

58-
5958
if (GC_ENABLE_IMEX)
6059
# normalize the value for lit config
6160
set(GC_ENABLE_IMEX ON)

include/gc/Dialect/Microkernel/MicrokernelOps.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
#ifndef GC_DIALECTS_MICROKERNELOPS_H
1010
#define GC_DIALECTS_MICROKERNELOPS_H
1111

12+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1213
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1314
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
1415
#include "mlir/Dialect/SCF/IR/SCF.h"
1516
#include "mlir/IR/BuiltinTypes.h"
1617
#include "mlir/IR/Dialect.h"
1718
#include "mlir/IR/OpDefinition.h"
19+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1820
#include "mlir/Interfaces/SideEffectInterfaces.h"
1921

2022
#include "gc/Dialect/Microkernel/MicrokernelDialect.h"

include/gc/Dialect/Microkernel/MicrokernelOps.td

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,79 @@
1111

1212
include "MicrokernelDialect.td"
1313
include "gc/Dialect/Microkernel/MicrokernelEnum.td"
14+
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
15+
include "mlir/Interfaces/DestinationStyleOpInterface.td"
1416
include "mlir/Interfaces/SideEffectInterfaces.td"
1517

18+
class StaticTensorRankOf<list<Type> allowedTypes, list<int> ranks> :
19+
Type<And<[TensorOf<allowedTypes>.predicate,
20+
HasAnyRankOfPred<ranks>, HasStaticShapePred]>,
21+
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " static " #
22+
TensorOf<allowedTypes>.summary, "::mlir::TensorType">;
23+
1624
class StaticMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
1725
Type<And<[MemRefOf<allowedTypes>.predicate,
1826
HasAnyRankOfPred<ranks>, HasStaticShapePred]>,
1927
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " static " #
2028
MemRefOf<allowedTypes>.summary, "::mlir::MemRefType">;
2129

30+
def BrgemmTensor : StaticTensorRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>;
31+
32+
def BrgemmTensorOrMemRef : AnyTypeOf<[StaticTensorRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>,
33+
StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>]>;
34+
35+
def Microkernel_BrgemmOp : Microkernel_Op<"brgemm",
36+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
37+
BufferizableOpInterface,
38+
DestinationStyleOpInterface]> {
39+
let summary = "Abstract Op that execute brgemm kernel on tensors.";
40+
let description = [{
41+
The operation has the following arguments:
42+
1) Tensors or MemRefs of operand A/B;
43+
2) The batch dims and leading dims of operand A/B;
44+
And has the following outputs:
45+
1) Tensor of operand C;
46+
}];
47+
48+
let arguments = (ins Variadic<BrgemmTensorOrMemRef>:$inputs,
49+
BrgemmTensorOrMemRef:$init,
50+
ConfinedAttr<DenseI64ArrayAttr,
51+
[DenseArrayNonNegative<DenseI64ArrayAttr>]>:$batchDims,
52+
ConfinedAttr<DenseI64ArrayAttr,
53+
[DenseArrayNonNegative<DenseI64ArrayAttr>]>:$leadingDims,
54+
TypedArrayAttrBase<Microkernel_BrgemmFlags, "brgemm flags">:$flags);
55+
let results = (outs Variadic<BrgemmTensor>:$output);
56+
57+
let extraClassDeclaration = [{
58+
Value getOperandA() { return getInputs()[0]; }
59+
Value getOperandB() { return getInputs()[1]; }
60+
Value getOperandC() { return getInit(); }
61+
62+
int64_t getBatchDimA() { return getBatchDims()[0]; }
63+
int64_t getLeadingDimA() { return getLeadingDims()[0]; }
64+
65+
int64_t getBatchDimB() { return getBatchDims()[1]; }
66+
int64_t getLeadingDimB() { return getLeadingDims()[1]; }
67+
68+
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
69+
70+
bool bufferizesToMemoryRead(OpOperand &,
71+
const bufferization::AnalysisState &);
72+
bool bufferizesToMemoryWrite(OpOperand &,
73+
const bufferization::AnalysisState &);
74+
bool bufferizesToElementwiseAccess(const bufferization::AnalysisState &,
75+
ArrayRef<OpOperand *>);
76+
bufferization::AliasingValueList getAliasingValues(OpOperand &opOperand,
77+
const bufferization::AnalysisState &state);
78+
LogicalResult bufferize(RewriterBase &,
79+
const bufferization::BufferizationOptions &);
80+
}];
81+
82+
let hasVerifier = 1;
83+
let hasCustomAssemblyFormat = 1;
84+
let hasFolder = 1;
85+
}
86+
2287
def Microkernel_BrgemmDispatchOp : Microkernel_Op<"brgemm.dispatch", [Pure]> {
2388
let summary = "JIT the brgemm microkernel given the parameters";
2489
let description = [{
@@ -80,7 +145,7 @@ def Microkernel_BrgemmEpilogueOp : Microkernel_Op<"brgemm.epilogue"> {
80145
*/
81146
def BrgemmMemRefOrI64 : AnyTypeOf<[StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>, I64]>;
82147

83-
def Microkernel_BrgemmOp : Microkernel_Op<"brgemm"> {
148+
def Microkernel_BrgemmExecuteOp : Microkernel_Op<"brgemm.execute"> {
84149
let summary = "execute the JITed brgemm kernel.";
85150
let description = [{
86151
The operation has the following arguments:

include/gc/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix GraphCompiler)
1111
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GraphCompiler)
1212
add_public_tablegen_target(GraphCompilerPassIncGen)
1313
add_mlir_doc(Passes GraphCompilerPasses ./ -gen-pass-doc)
14+
15+
add_subdirectory(Microkernel)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===-- BrgemmRuntimeUtils.h - Utils for Brgemm Runtime ---------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H
10+
#define GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H
11+
12+
#include "mlir/IR/Attributes.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "oneapi/dnnl/dnnl_types.h"
15+
16+
namespace mlir::microkernel {
17+
18+
// Symbol names of the BRGEMM runtime interfaces referenced by the MLIR passes.
19+
static const std::string DNNL_BRGEMM_DISPATCH_NAME = "dnnl_brgemm_dispatch";
20+
static const std::string DNNL_BRGEMM_TILECFG_NAME = "dnnl_brgemm_tileconfig";
21+
static const std::string DNNL_BRGEMM_TILERELEASE_NAME =
22+
"dnnl_brgemm_tilerelease";
23+
static const std::string DNNL_BRGEMM_EXECUTE_NAME = "dnnl_brgemm_execute";
24+
25+
static inline int64_t getDnnlDataTypeVal(RewriterBase &rewriter,
26+
Attribute attr) {
27+
auto context = rewriter.getContext();
28+
auto tattr = dyn_cast_or_null<TypeAttr>(attr);
29+
assert(tattr);
30+
if (tattr == TypeAttr::get(FloatType::getF32(context))) {
31+
return static_cast<int64_t>(dnnl_f32);
32+
} else if (tattr == TypeAttr::get(FloatType::getF64(context))) {
33+
return static_cast<int64_t>(dnnl_f64);
34+
} else if (tattr == TypeAttr::get(FloatType::getBF16(context))) {
35+
return static_cast<int64_t>(dnnl_bf16);
36+
} else if (tattr == TypeAttr::get(FloatType::getF16(context))) {
37+
return static_cast<int64_t>(dnnl_f16);
38+
} else if (tattr == TypeAttr::get(
39+
IntegerType::get(context, 32, IntegerType::Signed))) {
40+
return static_cast<int64_t>(dnnl_s32);
41+
} else if (tattr ==
42+
TypeAttr::get(IntegerType::get(context, 8, IntegerType::Signed))) {
43+
return static_cast<int64_t>(dnnl_s8);
44+
} else if (tattr == TypeAttr::get(IntegerType::get(context, 8,
45+
IntegerType::Unsigned))) {
46+
return static_cast<int64_t>(dnnl_u8);
47+
}
48+
return static_cast<int64_t>(dnnl_data_type_undef);
49+
}
50+
51+
}; // namespace mlir::microkernel
52+
53+
#endif // GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS MicrokernelPasses.td)
2+
mlir_tablegen(MicrokernelPasses.h.inc --gen-pass-decls -name Microkernel)
3+
mlir_tablegen(MicrokernelPasses.capi.h.inc -gen-pass-capi-header --prefix Microkernel)
4+
mlir_tablegen(MicrokernelPasses.capi.cpp.inc -gen-pass-capi-impl --prefix Microkernel)
5+
add_public_tablegen_target(MLIRMicrokernelPassesIncGen)
6+
add_mlir_doc(MicrokernelPasses GraphCompilerMicrokernelPasses ./ -gen-pass-doc)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- MicrokernelPasses.h - Graph Compiler microkernel passes --*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef GC_MICROKERNELPASSES_H
10+
#define GC_MICROKERNELPASSES_H
11+
12+
#include "gc/Dialect/Linalgx/LinalgxDialect.h"
13+
#include "gc/Dialect/Microkernel/MicrokernelDialect.h"
14+
#include "gc/Dialect/Microkernel/MicrokernelOps.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include <memory>
17+
18+
namespace mlir {
19+
namespace microkernel {
20+
#define GEN_PASS_DECL
21+
#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc"
22+
23+
#define GEN_PASS_REGISTRATION
24+
#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc"
25+
} // namespace microkernel
26+
} // namespace mlir
27+
28+
#endif // GC_MICROKERNELPASSES_H
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//===-- MicrokernelPasses.td - microkernel passes ----------*- tablegen -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef GC_DIALECT_MICROKERNELPASSES
10+
#define GC_DIALECT_MICROKERNELPASSES
11+
12+
include "mlir/Pass/PassBase.td"
13+
14+
def ConvertLinalgToMicrokernel: Pass<"convert-linalg-to-microkernel", "::mlir::func::FuncOp"> {
15+
let summary = "Lower eligible linalg ops to microkernels";
16+
let description = [{
17+
Convert eligible linalg ops to microkernel dialects based on pattern matching.
18+
For example:
19+
```
20+
scf.forall {
21+
linalg.fill ins(...) outs(...) -> tensor<...>
22+
linalg.batch_reduce_matmul ins(...) outs(...) -> tensor<...>
23+
}
24+
```
25+
Will be changed into
26+
```
27+
scf.forall {
28+
linalg.fill ins(...) outs(...) -> tensor<...>
29+
microkernel.brgemm ins(...) outs(...) -> tensor<...>
30+
}
31+
```
32+
}];
33+
let dependentDialects = ["func::FuncDialect",
34+
"tensor::TensorDialect",
35+
"memref::MemRefDialect",
36+
"linalg::LinalgDialect",
37+
"linalgx::LinalgxDialect",
38+
"microkernel::MicrokernelDialect"];
39+
}
40+
41+
def ExpandMicrokernel: Pass<"expand-microkernel", "::mlir::func::FuncOp"> {
42+
let summary = "Expand abstract microkernels into detailed execution phases";
43+
let description = [{
44+
Expand abstract microkernels into detailed execution phases
45+
For example:
46+
```
47+
scf.forall {
48+
linalg.fill ins(...) outs(...) -> tensor<...>
49+
microkernel.brgemm ins(...) outs(...) -> tensor<...>
50+
}
51+
```
52+
Will be changed into
53+
```
54+
scf.forall {
55+
linalg.fill ins(...) outs(...) -> tensor<...>
56+
%0 = microkernel.brgemm.dispatch(...)
57+
microkernel.brgemm.prologue(%0)
58+
microkernel.brgemm.execute(%0, ...)
59+
microkernel.brgemm.epilogue(%0)
60+
}
61+
```
62+
}];
63+
let dependentDialects = ["func::FuncDialect",
64+
"memref::MemRefDialect",
65+
"microkernel::MicrokernelDialect"];
66+
}
67+
68+
def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::mlir::ModuleOp"> {
69+
let summary = "Lower microkernel dialects to dnnl func call";
70+
let description = [{
71+
Convert microkernel dialects to runtime function call to oneDNN library.
72+
}];
73+
let dependentDialects = ["func::FuncDialect",
74+
"memref::MemRefDialect",
75+
"LLVM::LLVMDialect",
76+
"microkernel::MicrokernelDialect"];
77+
}
78+
79+
#endif // GC_DIALECT_MICROKERNELPASSES

include/gc/Transforms/Utils/StructuredOpMatcher.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,16 @@ template <typename T> struct EqualsTo {
217217
};
218218
template <typename T> EqualsTo(T) -> EqualsTo<T>;
219219

220+
// Callable object to check if the input is greater than or equal to specified
221+
// `value`.
222+
struct GreaterThanOrEqualTo {
223+
GreaterThanOrEqualTo() = delete;
224+
explicit GreaterThanOrEqualTo(size_t value) : value(value){};
225+
const size_t value;
226+
227+
bool operator()(size_t value) const { return value >= this->value; }
228+
};
229+
220230
// Callable object to validate number of init operands for `op`.
221231
struct NumDpsInits {
222232
NumDpsInits() = delete;

include/gc/Transforms/Utils/ValueUtils.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ FailureOr<SmallVector<int64_t>> getStaticStrides(Value val);
2727

2828
// Return the offset and ptr for `operand`. Assert if `operand`
2929
// is not a memref.
30-
std::pair<Value, Value> getPtrAndOffset(OpBuilder &builder, Value val,
31-
Location loc);
30+
std::pair<Value, Value> getPtrAndOffset(OpBuilder &builder, Value operand);
3231

3332
} // namespace utils
3433
} // namespace mlir

0 commit comments

Comments
 (0)