
Commit 60d3443

Merge branch 'xurui/benchgc_tuner' of https://github.com/xurui1995/graph-compiler into xurui/benchgc_tuner

2 parents: 1f237f9 + faf3e76

10 files changed: +486 −62

lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp

Lines changed: 47 additions & 33 deletions
@@ -53,15 +53,24 @@ using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
 using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
 static std::shared_mutex g_brgemm_lock;
 
-static std::vector<brgemm_desc_t> g_brgemm_desc_list;
-static std::vector<brgemm_kernel_t *> g_brgemm_kernel_list;
-static std::vector<std::unique_ptr<char[]>> g_brgemm_palette;
+struct brgemm_cache_info_t {
+  brgemm_desc_t desc;
+  brgemm_kernel_t *kernel;
+  std::shared_ptr<char[]> palette;
+};
+
+static std::vector<brgemm_cache_info_t> g_cache;
 
 // TODO(haixin): use syscall to determine page size?
 static constexpr size_t SCRATCH_SIZE = 2 * 4096;
 // TODO(haixin): need to use custom thread management for scratch in the future?
 static thread_local char scratch[SCRATCH_SIZE] = {0};
 
+static std::unordered_map<int64_t, brgemm_cache_info_t> &get_tl_cache() {
+  thread_local std::unordered_map<int64_t, brgemm_cache_info_t> tl_cache;
+  return tl_cache;
+}
+
 extern "C" {
 
 int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
@@ -93,33 +102,33 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
   brgemm_desc_set_attr(&desc, dnnl_attrs);
 
   // TODO(haixin): Reuse identical palettes across kernels
-  char *palette_buffer = nullptr;
+  std::shared_ptr<char[]> palette_buffer;
   if (desc.is_tmm) {
-    palette_buffer = new char[PALETTE_SIZE];
-    dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer);
+    palette_buffer.reset(new char[PALETTE_SIZE]);
+    dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer.get());
     assert(status == dnnl::impl::status::success &&
            "Failed to initialize palette for BRGEMM");
   }
 
   write_lock_guard_t g(g_brgemm_lock);
-  g_brgemm_desc_list.push_back(desc);
-  g_brgemm_kernel_list.push_back(kernel);
-  g_brgemm_palette.emplace_back(palette_buffer);
-
-  return g_brgemm_desc_list.size() - 1;
+  g_cache.push_back(brgemm_cache_info_t{desc, kernel, palette_buffer});
+  return g_cache.size() - 1;
 }
 
 void dnnl_brgemm_tileconfig(int64_t kernel_idx) {
-  char *palette_buffer = nullptr;
-  {
+  assert(kernel_idx >= 0 && "Invalid kernel handler");
+  auto &tl_cache = get_tl_cache();
+  auto it = tl_cache.find(kernel_idx);
+  if (it == tl_cache.end()) {
     read_lock_guard_t g(g_brgemm_lock);
-    assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() &&
-           "Invalid kernel handler");
-    brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx];
-    if (!desc.is_tmm) {
-      return;
-    }
-    palette_buffer = g_brgemm_palette[kernel_idx].get();
+    assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler");
+    it = tl_cache.insert({kernel_idx, g_cache[kernel_idx]}).first;
+  }
+  brgemm_desc_t &desc = it->second.desc;
+  char *palette_buffer = it->second.palette.get();
+
+  if (!desc.is_tmm) {
+    return;
   }
 
   assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel");
@@ -137,24 +146,29 @@ void dnnl_brgemm_tilerelease() {
 void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
                          void *B, uint64_t B_offset, void *C, uint64_t C_offset,
                          int num) {
-  brgemm_kernel_t *kernel = nullptr;
-  size_t A_offset_in_bytes;
-  size_t B_offset_in_bytes;
-  size_t C_offset_in_bytes;
-  {
+  auto &tl_cache = get_tl_cache();
+  if (tl_cache.find(kernel_idx) == tl_cache.end()) {
    read_lock_guard_t g(g_brgemm_lock);
-    assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() &&
+    assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
           "Invalid kernel handler");
-
-    brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx];
-    kernel = g_brgemm_kernel_list[kernel_idx];
-
-    A_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_a) * A_offset;
-    B_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_b) * B_offset;
-    C_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_c) * C_offset;
+    auto updated_cache =
+        tl_cache.insert(std::make_pair(kernel_idx, g_cache[kernel_idx]));
+    assert(updated_cache.second && "insert into thread local cache");
   }
+  auto it = tl_cache.find(kernel_idx);
+  brgemm_kernel_t *kernel = it->second.kernel;
+  brgemm_desc_t *desc_ptr = &it->second.desc;
 
   assert(kernel && "Invalid brgemm kernel pointer");
+  assert(desc_ptr && "Invalid brgemm descriptor pointer");
+
+  size_t A_offset_in_bytes =
+      dnnl::impl::types::data_type_size(desc_ptr->dt_a) * A_offset;
+  size_t B_offset_in_bytes =
+      dnnl::impl::types::data_type_size(desc_ptr->dt_b) * B_offset;
+  size_t C_offset_in_bytes =
+      dnnl::impl::types::data_type_size(desc_ptr->dt_c) * C_offset;
+
   char *A_arith = (char *)A;
   char *B_arith = (char *)B;
   char *C_arith = (char *)C;
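The net effect of this change is that `dnnl_brgemm_tileconfig` and `dnnl_brgemm_execute` no longer take the shared lock on every call: each thread copies the `brgemm_cache_info_t` entry into a thread-local map on first use and reads its private copy afterwards. Below is a minimal standalone sketch of the same pattern; `entry_t` and `lookup()` are hypothetical stand-ins, not the runtime's actual API.

// Sketch: thread-local read cache over a lock-protected global registry.
#include <cstdint>
#include <mutex>
#include <shared_mutex>
#include <unordered_map>
#include <vector>

struct entry_t {
  int payload;
};

static std::shared_mutex g_lock;
static std::vector<entry_t> g_entries; // appended under a unique_lock

static entry_t &lookup(int64_t idx) {
  thread_local std::unordered_map<int64_t, entry_t> tl_cache;
  auto it = tl_cache.find(idx);
  if (it == tl_cache.end()) {
    // slow path: first access on this thread copies the shared entry
    std::shared_lock<std::shared_mutex> guard(g_lock);
    it = tl_cache.insert({idx, g_entries[idx]}).first;
  }
  return it->second; // fast path: no locking, no sharing
}

int main() {
  {
    std::unique_lock<std::shared_mutex> guard(g_lock);
    g_entries.push_back({42}); // registration, like dnnl_brgemm_dispatch
  }
  return lookup(0).payload == 42 ? 0 : 1;
}

Copying entries per thread is also why the palette moves from `std::unique_ptr<char[]>` to `std::shared_ptr<char[]>`: every thread-local copy shares ownership of the one palette buffer, so the copies stay cheap and the buffer outlives any of them.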

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -13,8 +13,8 @@ gc_add_mlir_library(GcPasses
   MemRefToCPURuntime.cpp
   OneDNNGraphToLinalg.cpp
   Pipeline.cpp
+  TileUsingInterfaceX.cpp
   IterativeTilingAndFusion.cpp
-  TilingUsingInterfaceX.cpp
   VerifyTargetDescription.cpp
   DecomposeAggregatedOps.cpp
   DeepTileContractionOp.cpp

lib/gc/Transforms/IterativeTilingAndFusion.cpp

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@
 #include <memory>
 #include <unordered_map>
 
-#include "TilingUsingInterfaceX.h"
+#include "TileUsingInterfaceX.h"
 
 namespace mlir {
 namespace gc {

lib/gc/Transforms/MergeAllocTickBased.cpp

Lines changed: 208 additions & 15 deletions
@@ -30,6 +30,65 @@ using namespace special_ticks;
 /// and default memory space.
 static bool isMemRefTypeOk(MemRefType type) { return type.hasStaticShape(); }
 
+static inline int64_t getSizeInBytes(MemRefType &memType) {
+  // treat bool (i1) as 1 byte. It may not be true for all targets, but we at
+  // least have a large enough size for i1
+  int64_t size = memType.getElementTypeBitWidth() / 8;
+  size = (size > 0) ? size : 1;
+  for (auto v : memType.getShape()) {
+    size *= v;
+  }
+  return size;
+}
+
+static bool needsHoistOutOfParallelLoop(Operation *op) {
+  Operation *parent =
+      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  if (isa_and_nonnull<scf::ForallOp>(parent)) {
+    // check if the current allocation sits between the nested parallel loops
+    // and is used inside the inner parallel loop
+    SmallVector<Operation *, 4> parallelOpInCurBlock;
+    Block *curBlock = op->getBlock();
+    for (auto &curOp : curBlock->getOperations()) {
+      if (isa<scf::ForallOp>(curOp)) {
+        parallelOpInCurBlock.push_back(&curOp);
+      }
+    }
+
+    if (parallelOpInCurBlock.empty())
+      return false;
+
+    for (auto *use : op->getUsers()) {
+      for (auto *parallelOp : parallelOpInCurBlock) {
+        if (parallelOp->isAncestor(use)) {
+          return true;
+        }
+      }
+    }
+  }
+
+  return false;
+}
+
+static bool isForallLoopBoundStatic(Operation *op) {
+  auto forallOp = dyn_cast<scf::ForallOp>(op);
+  if (!forallOp)
+    return false;
+
+  auto lbs = forallOp.getMixedLowerBound();
+  auto ubs = forallOp.getMixedUpperBound();
+  auto steps = forallOp.getMixedStep();
+  auto allConstantValue = [](SmallVector<OpFoldResult> vals) -> bool {
+    return llvm::all_of(vals, [](OpFoldResult val) {
+      std::optional<int64_t> const_val = getConstantIntValue(val);
+      return const_val.has_value();
+    });
+  };
+
+  return allConstantValue(lbs) && allConstantValue(ubs) &&
+         allConstantValue(steps);
+}
+
 void Tick::update(int64_t tick) {
   if (tick == UNTRACEABLE_ACCESS) {
     firstAccess = UNTRACEABLE_ACCESS;
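`isForallLoopBoundStatic` is what later lets the pass fold loop trip counts into compile-time constants. A small sketch of that arithmetic follows; `tripCount` is a hypothetical stand-in that mirrors the math of MLIR's `constantTripCount`, i.e. ceildiv(ub − lb, step) when all three bounds are constant.

// Sketch of the trip-count math used once all forall bounds are static.
#include <cassert>
#include <cstdint>
#include <optional>

static std::optional<int64_t> tripCount(int64_t lb, int64_t ub, int64_t step) {
  if (step <= 0)
    return std::nullopt; // non-positive step: not a countable loop
  return (ub - lb + step - 1) / step; // ceil((ub - lb) / step)
}

int main() {
  assert(tripCount(0, 8, 1) == 8);  // iterations 0..7
  assert(tripCount(0, 7, 2) == 4);  // iterations 0, 2, 4, 6
  assert(tripCount(2, 10, 4) == 2); // iterations 2, 6
  return 0;
}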
@@ -180,28 +239,60 @@ bool TickCollecter::isMergeableAlloc(TickCollecterStates *s, Operation *op,
 // trait, and is not scf.for
 Operation *TickCollecter::getAllocScope(TickCollecterStates *s,
                                         Operation *op) const {
-  auto parent = op;
+  Operation *parent = op;
+  bool moveToUpperParellelLoop = needsHoistOutOfParallelLoop(op);
+
   for (;;) {
     parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
     if (!parent) {
       return nullptr;
     }
-    if (!isa<scf::ForOp>(parent)) {
-      return parent;
-    }
+
+    if (isa<scf::ForOp>(parent))
+      continue;
+
+    if (isa<scf::ForallOp>(parent) &&
+        (moveToUpperParellelLoop && isForallLoopBoundStatic(parent)))
+      continue;
+
+    return parent;
   }
 }
 
 FailureOr<size_t> TickCollecter::getAllocSize(TickCollecterStates *s,
                                               Operation *op) const {
   auto refType = cast<MemRefType>(op->getResultTypes().front());
-  int64_t size = refType.getElementTypeBitWidth() / 8;
-  // treat bool (i1) as 1 byte. It may not be true for all targets, but we at
-  // least have a large enough size for i1
-  size = (size != 0) ? size : 1;
-  for (auto v : refType.getShape()) {
-    size *= v;
+
+  // Get the total number of threads from the outermost to the current level of
+  // the parallel loop that the allocation is located in.
+  int64_t numThreads = 1;
+  if (needsHoistOutOfParallelLoop(op)) {
+    Operation *parent =
+        op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+    while (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
+      if (!isForallLoopBoundStatic(forallOp))
+        break;
+
+      OpBuilder builder{forallOp->getContext()};
+      std::optional<int64_t> numIterations;
+      for (auto [lb, ub, step] : llvm::zip(forallOp.getLowerBound(builder),
+                                           forallOp.getUpperBound(builder),
+                                           forallOp.getStep(builder))) {
+        numIterations = constantTripCount(lb, ub, step);
+        if (numIterations.has_value()) {
+          numThreads *= numIterations.value();
+        } else {
+          return op->emitError("Expecting static loop range!");
+        }
+      }
+
+      parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+    }
   }
+  assert(numThreads > 0);
+
+  int64_t size = getSizeInBytes(refType);
+  size *= numThreads;
   if (size > 0) {
     return static_cast<size_t>(size);
   }
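When an allocation is hoisted, the merged buffer must hold one copy per thread of the enclosing static forall nest, so `getAllocSize` scales the single-buffer size by the product of all trip counts. A worked example with hypothetical shape and bounds:

// Worked example: an 8x32xf32 buffer hoisted out of a nest
// scf.forall (0..4) { scf.forall (0..2) { ... } } (hypothetical bounds).
#include <cassert>
#include <cstdint>

int main() {
  int64_t elemBytes = 32 / 8;            // f32 element -> 4 bytes
  int64_t bufBytes = elemBytes * 8 * 32; // getSizeInBytes: 1024 per thread
  int64_t numThreads = 4 * 2;            // product of the static trip counts
  assert(bufBytes * numThreads == 8192); // bytes reserved in the merged slab
  return 0;
}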
@@ -391,11 +482,113 @@ Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Block *scope,
                                           Value mergedAlloc,
                                           int64_t byteOffset) const {
   builder.setInsertionPoint(origAllocOp);
-  auto byteShift =
-      builder.create<arith::ConstantIndexOp>(origAllocOp->getLoc(), byteOffset);
-  return builder.create<memref::ViewOp>(origAllocOp->getLoc(),
-                                        origAllocOp->getResultTypes().front(),
-                                        mergedAlloc, byteShift, ValueRange{});
+  auto loc = origAllocOp->getLoc();
+  auto byteShift = builder.create<arith::ConstantIndexOp>(loc, byteOffset);
+
+  bool moveToUpperParellelLoop = needsHoistOutOfParallelLoop(origAllocOp);
+  Operation *parent =
+      origAllocOp->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  if (!moveToUpperParellelLoop || !parent || !isa<scf::ForallOp>(parent))
+    return builder.create<memref::ViewOp>(loc,
+                                          origAllocOp->getResultTypes().front(),
+                                          mergedAlloc, byteShift, ValueRange{});
+
+  // get the aggregated induction var
+  Value inductVar;
+  bool isOuterMostLoop = true;
+  int64_t innerLoopUpperBound = 1;
+  while (parent) {
+    if (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
+      if (isForallLoopBoundStatic(forallOp)) {
+        SmallVector<Value> ubs = forallOp.getUpperBound(builder);
+        SmallVector<Value> lbs = forallOp.getLowerBound(builder);
+        SmallVector<Value> steps = forallOp.getStep(builder);
+        SmallVector<Value> inductionVars = forallOp.getInductionVars();
+
+        auto getCurrentVar = [&loc, &builder](Value var, Value lb,
+                                              Value step) -> Value {
+          if (!isConstantIntValue(lb, 0))
+            var = builder.create<arith::SubIOp>(loc, var, lb);
+
+          if (!isConstantIntValue(step, 1))
+            var = builder.create<arith::DivSIOp>(loc, var, step);
+          return var;
+        };
+
+        auto getAggregatedVar =
+            [&loc, &builder, &getCurrentVar](
+                const SmallVector<Value> &_lbs, const SmallVector<Value> &_ubs,
+                const SmallVector<Value> &_steps,
+                const SmallVector<Value> &_inductVars) -> Value {
+          Value var;
+          if (_ubs.size() == 1) {
+            var = getCurrentVar(_inductVars[0], _lbs[0], _steps[0]);
+            return var;
+          } else {
+            bool isFirstLoop = true;
+            for (auto [lb, ub, step, inductVar] :
+                 llvm::zip(_lbs, _ubs, _steps, _inductVars)) {
+              if (isFirstLoop) {
+                var = getCurrentVar(inductVar, lb, step);
+                isFirstLoop = false;
+              } else {
+                Value cur_var = getCurrentVar(inductVar, lb, step);
+                std::optional<int64_t> bound = constantTripCount(lb, ub, step);
+                assert(bound.has_value());
+                Value boundVal =
+                    builder.create<arith::ConstantIndexOp>(loc, bound.value());
+                Value tmpVal =
+                    builder.create<arith::MulIOp>(loc, var, boundVal);
+                var = builder.create<arith::AddIOp>(loc, tmpVal, cur_var);
+              }
+            }
+            return var;
+          }
+        };
+
+        if (isOuterMostLoop) {
+          inductVar = getAggregatedVar(lbs, ubs, steps, inductionVars);
+          isOuterMostLoop = false;
+        } else {
+          Value currentVar = getAggregatedVar(lbs, ubs, steps, inductionVars);
+
+          Value innerLoopBoundVal =
+              builder.create<arith::ConstantIndexOp>(loc, innerLoopUpperBound);
+          Value intermediateVal =
+              builder.create<arith::MulIOp>(loc, currentVar, innerLoopBoundVal);
+          inductVar =
+              builder.create<arith::AddIOp>(loc, inductVar, intermediateVal);
+        }
+        // get aggregated loop bound
+        for (auto [lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
+          std::optional<int64_t> cur_bound = constantTripCount(lb, ub, step);
+          assert(cur_bound.has_value());
+          innerLoopUpperBound *= cur_bound.value();
+        }
+      }
+    }
+
+    parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  }
+
+  if (!isOuterMostLoop) {
+    // get original shape size
+    auto memType = cast<MemRefType>(origAllocOp->getResultTypes().front());
+    int64_t size = getSizeInBytes(memType);
+    Value origSize = builder.create<arith::ConstantIndexOp>(loc, size);
+    Value offsetPerThread =
+        builder.create<arith::MulIOp>(loc, inductVar, origSize);
+    Value byteShiftPerThread =
+        builder.create<arith::AddIOp>(loc, byteShift, offsetPerThread);
+
+    return builder.create<memref::ViewOp>(
+        loc, origAllocOp->getResultTypes().front(), mergedAlloc,
+        byteShiftPerThread, ValueRange{});
+  } else {
+    return builder.create<memref::ViewOp>(loc,
+                                          origAllocOp->getResultTypes().front(),
+                                          mergedAlloc, byteShift, ValueRange{});
+  }
 }
 
 LogicalResult
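`buildView` then gives each thread a disjoint slice of that enlarged region: the induction variables of the nest are linearized into a single thread id, scaled by the original buffer size, and added to the buffer's static offset in the merged slab. A worked example of that offset arithmetic, reusing the same hypothetical 4x2 nest:

// Worked example of the per-thread view offset computed in buildView:
// outer forall i in [0, 4), inner forall j in [0, 2) (hypothetical bounds).
#include <cassert>
#include <cstdint>

int main() {
  int64_t bufBytes = 1024; // size of one per-thread copy (getSizeInBytes)
  int64_t byteShift = 512; // static offset of this alloc in the merged slab
  for (int64_t i = 0; i < 4; ++i) {
    for (int64_t j = 0; j < 2; ++j) {
      // aggregation visits the inner loop first, then adds the outer var
      // scaled by the inner nest's total trip count (innerLoopUpperBound == 2)
      int64_t inductVar = j + i * 2;
      int64_t offset = byteShift + inductVar * bufBytes;
      assert(offset == byteShift + (i * 2 + j) * 1024); // disjoint slices
    }
  }
  return 0;
}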
