Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit d8b948d

Browse files
committed
Fix parallel hash joins to merge thread-local hash tables in parallel. Also fix memory leaks by properly cleaning up after pipelines.
1 parent cc9a025 commit d8b948d

File tree

13 files changed

+197
-87
lines changed

13 files changed

+197
-87
lines changed

src/codegen/code_context.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,14 @@ class InstructionCounts : public llvm::ModulePass {
142142
}
143143

144144
void DumpStats() const {
145-
#ifndef NDEBUG
146-
LOG_DEBUG("# functions: %" PRId64 " (%" PRId64
145+
LOG_INFO("# functions: %" PRId64 " (%" PRId64
147146
" external), # blocks: %" PRId64 ", # instructions: %" PRId64,
148147
func_count_, external_func_count_, basic_block_count_,
149148
total_inst_counts_);
150149
for (const auto iter : counts_) {
151150
const char *inst_name = llvm::Instruction::getOpcodeName(iter.first);
152-
LOG_DEBUG("↳ %s: %" PRId64, inst_name, iter.second);
151+
LOG_INFO("↳ %s: %" PRId64, inst_name, iter.second);
153152
}
154-
#endif
155153
}
156154

157155
private:

src/codegen/hash_table.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,10 @@ void HashTable::BuildLazy(CodeGen &codegen, llvm::Value *ht_ptr) const {
163163
}
164164

165165
void HashTable::ReserveLazy(CodeGen &codegen, llvm::Value *ht_ptr,
166-
llvm::Value *thread_states) const {
167-
codegen.Call(HashTableProxy::BuildLazy, {ht_ptr, thread_states});
166+
llvm::Value *thread_states,
167+
uint32_t ht_state_offset) const {
168+
codegen.Call(HashTableProxy::ReserveLazy,
169+
{ht_ptr, thread_states, codegen.Const32(ht_state_offset)});
168170
}
169171

170172
void HashTable::MergeLazyUnfinished(CodeGen &codegen, llvm::Value *global_ht,

src/codegen/operator/hash_join_translator.cpp

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ void HashJoinTranslator::Consume(ConsumerContext &context,
343343
}
344344

345345
// The given row is coming from the left child. Insert into hash table
346-
void HashJoinTranslator::ConsumeFromLeft(ConsumerContext &,
346+
void HashJoinTranslator::ConsumeFromLeft(ConsumerContext &ctx,
347347
RowBatch::Row &row) const {
348348
CodeGen &codegen = GetCodeGen();
349349

@@ -363,61 +363,69 @@ void HashJoinTranslator::ConsumeFromLeft(ConsumerContext &,
363363
}
364364
#endif
365365

366+
llvm::Value *ht_ptr = nullptr;
367+
if (ctx.GetPipeline().IsParallel()) {
368+
ht_ptr = ctx.GetPipelineContext()->LoadStatePtr(codegen, hash_table_tl_id_);
369+
} else {
370+
ht_ptr = LoadStatePtr(hash_table_id_);
371+
}
372+
366373
// Insert tuples from the left side into the hash table
367374
InsertLeft insert_left{left_value_storage_, vals};
368-
hash_table_.InsertLazy(codegen, LoadStatePtr(hash_table_id_), hash, key,
369-
insert_left);
375+
hash_table_.InsertLazy(codegen, ht_ptr, hash, key, insert_left);
370376

377+
// Update bloom filter, if enabled
371378
if (GetJoinPlan().IsBloomFilterEnabled()) {
372-
// Insert tuples into the bloom filter if enabled
373379
bloom_filter_.Add(codegen, LoadStatePtr(bloom_filter_id_), key);
374380
}
375381
}
376382

377-
void HashJoinTranslator::RegisterPipelineState(PipelineContext &context) {
378-
if (context.IsParallel()) {
379-
hash_table_tl_id_ =
380-
context.RegisterState("localHT", HashTableProxy::GetType(GetCodeGen()));
383+
void HashJoinTranslator::RegisterPipelineState(PipelineContext &pipeline_ctx) {
384+
if (pipeline_ctx.IsParallel() && IsLeftPipeline(pipeline_ctx.GetPipeline())) {
385+
hash_table_tl_id_ = pipeline_ctx.RegisterState(
386+
"localHT", HashTableProxy::GetType(GetCodeGen()));
381387
}
382388
}
383389

384-
void HashJoinTranslator::InitializePipelineState(PipelineContext &context) {
385-
if (context.IsParallel() && IsLeftPipeline(context.GetPipeline())) {
390+
void HashJoinTranslator::InitializePipelineState(
391+
PipelineContext &pipeline_ctx) {
392+
if (pipeline_ctx.IsParallel() && IsLeftPipeline(pipeline_ctx.GetPipeline())) {
386393
CodeGen &codegen = GetCodeGen();
387394
hash_table_.Init(codegen, GetExecutorContextPtr(),
388-
context.LoadStatePtr(codegen, hash_table_tl_id_));
395+
pipeline_ctx.LoadStatePtr(codegen, hash_table_tl_id_));
389396
}
390397
}
391398

392-
void HashJoinTranslator::FinishPipeline(PipelineContext &context) {
393-
// We only need to do post-pipeline processing work in the left pipeline
394-
if (context.GetPipeline() != left_pipeline_) {
395-
return;
396-
}
397-
398-
llvm::Value *global_ht_ptr = LoadStatePtr(hash_table_id_);
399-
400-
if (!context.IsParallel()) {
401-
// Build the hash table over the lazily inserted tuples
402-
hash_table_.BuildLazy(GetCodeGen(), global_ht_ptr);
403-
} else {
399+
void HashJoinTranslator::FinishPipeline(PipelineContext &pipeline_ctx) {
400+
if (IsLeftPipeline(pipeline_ctx.GetPipeline())) {
404401
CodeGen &codegen = GetCodeGen();
405-
406-
llvm::Value *local_ht_ptr =
407-
context.LoadStatePtr(codegen, hash_table_tl_id_);
408-
409-
// First size the global hash table
410-
hash_table_.ReserveLazy(codegen, global_ht_ptr, GetThreadStatesPtr());
411-
412-
// Then merge each local table in parallel
413-
hash_table_.MergeLazyUnfinished(codegen, global_ht_ptr, local_ht_ptr);
402+
llvm::Value *global_ht_ptr = LoadStatePtr(hash_table_id_);
403+
if (!pipeline_ctx.IsParallel()) {
404+
// Build the hash table over the lazily inserted tuples
405+
hash_table_.BuildLazy(codegen, global_ht_ptr);
406+
} else {
407+
// First size the global hash table
408+
hash_table_.ReserveLazy(
409+
codegen, global_ht_ptr, GetThreadStatesPtr(),
410+
pipeline_ctx.GetEntryOffset(codegen, hash_table_tl_id_));
411+
412+
// Then merge each local table in parallel
413+
PipelineContext::LoopOverStates loop_states{pipeline_ctx};
414+
loop_states.DoParallel([this, &pipeline_ctx, &codegen](
415+
UNUSED_ATTRIBUTE llvm::Value *thread_state) {
416+
llvm::Value *global_ht_ptr = LoadStatePtr(hash_table_id_);
417+
llvm::Value *local_ht_ptr =
418+
pipeline_ctx.LoadStatePtr(codegen, hash_table_tl_id_);
419+
hash_table_.MergeLazyUnfinished(codegen, global_ht_ptr, local_ht_ptr);
420+
});
421+
}
414422
}
415423
}
416424

417-
void HashJoinTranslator::TearDownPipelineState(PipelineContext &context) {
418-
if (context.IsParallel() && IsLeftPipeline(context.GetPipeline())) {
425+
void HashJoinTranslator::TearDownPipelineState(PipelineContext &pipeline_ctx) {
426+
if (pipeline_ctx.IsParallel() && IsLeftPipeline(pipeline_ctx.GetPipeline())) {
419427
CodeGen &codegen = GetCodeGen();
420-
auto *local_ht_ptr = context.LoadStatePtr(codegen, hash_table_tl_id_);
428+
auto *local_ht_ptr = pipeline_ctx.LoadStatePtr(codegen, hash_table_tl_id_);
421429
hash_table_.Destroy(codegen, local_ht_ptr);
422430
}
423431
}
@@ -434,7 +442,7 @@ void HashJoinTranslator::ConsumeFromRight(ConsumerContext &context,
434442
llvm::Value *contains = bloom_filter_.Contains(
435443
GetCodeGen(), LoadStatePtr(bloom_filter_id_), key);
436444

437-
lang::If is_valid_row(GetCodeGen(), contains);
445+
lang::If is_valid_row{GetCodeGen(), contains};
438446
{
439447
// For each tuple that passes the bloom filter, probe the hash table
440448
// to eliminate the false positives.

src/codegen/pipeline.cpp

Lines changed: 95 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,25 @@
1818
#include "codegen/consumer_context.h"
1919
#include "codegen/lang/loop.h"
2020
#include "codegen/proxy/executor_context_proxy.h"
21+
#include "codegen/proxy/runtime_functions_proxy.h"
2122
#include "settings/settings_manager.h"
2223

2324
namespace peloton {
2425
namespace codegen {
2526

27+
namespace {
28+
29+
std::string CreateUniqueFunctionName(Pipeline &pipeline,
30+
const std::string &prefix) {
31+
CompilationContext &compilation_ctx = pipeline.GetCompilationContext();
32+
CodeContext &cc = compilation_ctx.GetCodeGen().GetCodeContext();
33+
return StringUtil::Format("_%" PRId64 "_pipeline_%u_%s_%s", cc.GetID(),
34+
pipeline.GetId(), prefix.c_str(),
35+
pipeline.ConstructPipelineName().c_str());
36+
}
37+
38+
} // namespace
39+
2640
////////////////////////////////////////////////////////////////////////////////
2741
///
2842
/// LoopOverStates
@@ -36,30 +50,83 @@ void PipelineContext::LoopOverStates::Do(
3650
const std::function<void(llvm::Value *)> &body) const {
3751
auto &compilation_ctx = ctx_.GetPipeline().GetCompilationContext();
3852
auto &exec_consumer = compilation_ctx.GetExecutionConsumer();
39-
auto *thread_states = exec_consumer.GetThreadStatesPtr(compilation_ctx);
53+
auto &codegen = compilation_ctx.GetCodeGen();
4054

41-
CodeGen &codegen = compilation_ctx.GetCodeGen();
55+
llvm::Value *thread_states =
56+
exec_consumer.GetThreadStatesPtr(compilation_ctx);
4257

4358
llvm::Value *num_threads =
4459
codegen.Load(ThreadStatesProxy::num_threads, thread_states);
4560
llvm::Value *state_size =
4661
codegen.Load(ThreadStatesProxy::state_size, thread_states);
47-
llvm::Value *states = codegen.Load(ThreadStatesProxy::states, thread_states);
4862

49-
llvm::Value *state_end = codegen->CreateInBoundsGEP(
50-
states, {codegen->CreateMul(num_threads, state_size)});
63+
llvm::Value *states = codegen.Load(ThreadStatesProxy::states, thread_states);
64+
states = codegen->CreatePointerCast(states, codegen.CharPtrType());
5165

52-
llvm::Value *loop_cond = codegen->CreateICmpNE(states, state_end);
53-
lang::Loop state_loop{codegen, loop_cond, {{"threadState", states}}};
66+
llvm::Value *tid = codegen.Const32(0);
67+
llvm::Value *loop_cond = codegen->CreateICmpNE(tid, num_threads);
68+
lang::Loop state_loop{codegen, loop_cond, {{"tid", tid}}};
5469
{
55-
// Pull out state in this iteration
56-
llvm::Value *curr_state = state_loop.GetLoopVar(0);
70+
// Pull out state for current TID
71+
tid = state_loop.GetLoopVar(0);
72+
llvm::Value *offset = codegen->CreateMul(tid, state_size);
73+
74+
llvm::Value *raw_ptr = codegen->CreateInBoundsGEP(states, {offset});
75+
llvm::Value *state = codegen->CreatePointerCast(
76+
raw_ptr, ctx_.GetThreadStateType()->getPointerTo());
77+
5778
// Invoke caller
58-
body(curr_state);
79+
body(state);
80+
5981
// Wrap up
60-
states = codegen->CreateInBoundsGEP(states, {state_size});
61-
state_loop.LoopEnd(codegen->CreateICmpNE(states, state_end), {states});
82+
tid = codegen->CreateAdd(tid, codegen.Const32(1));
83+
state_loop.LoopEnd(codegen->CreateICmpNE(tid, num_threads), {tid});
84+
}
85+
}
86+
87+
void PipelineContext::LoopOverStates::DoParallel(
88+
const std::function<void(llvm::Value *)> &body) const {
89+
Pipeline &pipeline = ctx_.GetPipeline();
90+
CompilationContext &comp_ctx = pipeline.GetCompilationContext();
91+
QueryState &query_state = comp_ctx.GetQueryState();
92+
CodeGen &codegen = comp_ctx.GetCodeGen();
93+
94+
auto name = CreateUniqueFunctionName(pipeline, "loopThreadState");
95+
96+
std::vector<FunctionDeclaration::ArgumentInfo> args = {
97+
{"queryState", query_state.GetType()->getPointerTo()},
98+
{"threadState", ctx_.GetThreadStateType()->getPointerTo()}};
99+
FunctionDeclaration decl{codegen.GetCodeContext(), name,
100+
FunctionDeclaration::Visibility::Internal,
101+
codegen.VoidType(), args};
102+
FunctionBuilder func{codegen.GetCodeContext(), decl};
103+
{
104+
// Pull out arguments
105+
auto *thread_state_ptr = func.GetArgumentByPosition(1);
106+
107+
// Setup access to the thread state
108+
PipelineContext::ScopedStateAccess state_access{
109+
ctx_, func.GetArgumentByPosition(1)};
110+
111+
// Execute function body
112+
body(thread_state_ptr);
113+
114+
// Finish
115+
func.ReturnAndFinish();
62116
}
117+
118+
// Invoke the per-state dispatch function
119+
120+
std::vector<llvm::Value *> dispatch_args = {
121+
// The (void*) query state
122+
codegen->CreatePointerCast(codegen.GetState(), codegen.VoidPtrType()),
123+
// The (ThreadStates &) thread states
124+
comp_ctx.GetExecutionConsumer().GetThreadStatesPtr(comp_ctx),
125+
// The function
126+
codegen->CreatePointerCast(
127+
func.GetFunction(),
128+
proxy::TypeBuilder<void (*)(void *, void *)>::GetType(codegen))};
129+
codegen.Call(RuntimeFunctionsProxy::ExecutePerState, dispatch_args);
63130
}
64131

65132
////////////////////////////////////////////////////////////////////////////////
@@ -150,6 +217,12 @@ uint32_t PipelineContext::GetEntryOffset(CodeGen &codegen,
150217
return static_cast<uint32_t>(codegen.ElementOffset(state_type, state_id));
151218
}
152219

220+
bool PipelineContext::HasState() const {
221+
PELOTON_ASSERT(thread_state_type_ != nullptr &&
222+
"Cannot query state components until it has been finalized");
223+
return state_components_.size() > 1;
224+
}
225+
153226
bool PipelineContext::IsParallel() const { return pipeline_.IsParallel(); }
154227

155228
Pipeline &PipelineContext::GetPipeline() { return pipeline_; }
@@ -283,19 +356,6 @@ uint32_t Pipeline::GetTranslatorStage(
283356
///
284357
////////////////////////////////////////////////////////////////////////////////
285358

286-
namespace {
287-
288-
std::string CreateUniqueFunctionName(Pipeline &pipeline,
289-
const std::string &prefix) {
290-
CompilationContext &compilation_ctx = pipeline.GetCompilationContext();
291-
CodeContext &cc = compilation_ctx.GetCodeGen().GetCodeContext();
292-
return StringUtil::Format("_%" PRId64 "_pipeline_%u_%s_%s", cc.GetID(),
293-
pipeline.GetId(), prefix.c_str(),
294-
pipeline.ConstructPipelineName().c_str());
295-
}
296-
297-
} // namespace
298-
299359
std::string Pipeline::ConstructPipelineName() const {
300360
std::vector<std::string> parts;
301361
for (auto riter = pipeline_.rbegin(), rend = pipeline_.rend(); riter != rend;
@@ -354,8 +414,8 @@ void Pipeline::InitializePipeline(PipelineContext &pipeline_ctx) {
354414
{"queryState", query_state.GetType()->getPointerTo()},
355415
{"threadState", pipeline_ctx.GetThreadStateType()->getPointerTo()}};
356416

357-
FunctionDeclaration init_decl(cc, func_name, visibility, ret_type, args);
358-
FunctionBuilder init_func(cc, init_decl);
417+
FunctionDeclaration init_decl{cc, func_name, visibility, ret_type, args};
418+
FunctionBuilder init_func{cc, init_decl};
359419
{
360420
PipelineContext::ScopedStateAccess state_access{
361421
pipeline_ctx, init_func.GetArgumentByPosition(1)};
@@ -386,7 +446,11 @@ void Pipeline::CompletePipeline(PipelineContext &pipeline_ctx) {
386446
return;
387447
}
388448

389-
// Loop over all states
449+
if (!pipeline_ctx.HasState()) {
450+
return;
451+
}
452+
453+
// Loop over all states to allow operators to clean up components
390454
PipelineContext::LoopOverStates loop_state{pipeline_ctx};
391455
loop_state.Do([this, &pipeline_ctx](llvm::Value *thread_state) {
392456
PipelineContext::ScopedStateAccess state_access{pipeline_ctx, thread_state};
@@ -429,8 +493,7 @@ void Pipeline::Run(
429493
InitializePipeline(pipeline_ctx);
430494

431495
// Generate pipeline
432-
DoRun(pipeline_ctx, dispatch_func, dispatch_args, pipeline_arg_types,
433-
body);
496+
DoRun(pipeline_ctx, dispatch_func, dispatch_args, pipeline_arg_types, body);
434497

435498
// Finish
436499
CompletePipeline(pipeline_ctx);
@@ -460,8 +523,8 @@ void Pipeline::DoRun(
460523
}
461524

462525
// The main function
463-
FunctionDeclaration declaration(cc, func_name, visibility, ret_type, args);
464-
FunctionBuilder func(cc, declaration);
526+
FunctionDeclaration declaration{cc, func_name, visibility, ret_type, args};
527+
FunctionBuilder func{cc, declaration};
465528
{
466529
auto *query_state = func.GetArgumentByPosition(0);
467530
auto *thread_state = func.GetArgumentByPosition(1);

src/codegen/proxy/runtime_functions_proxy.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ DEFINE_METHOD(peloton::codegen, RuntimeFunctions, GetTileGroup);
3232
DEFINE_METHOD(peloton::codegen, RuntimeFunctions, GetTileGroupLayout);
3333
DEFINE_METHOD(peloton::codegen, RuntimeFunctions, FillPredicateArray);
3434
DEFINE_METHOD(peloton::codegen, RuntimeFunctions, ExecuteTableScan);
35+
DEFINE_METHOD(peloton::codegen, RuntimeFunctions, ExecutePerState);
3536
DEFINE_METHOD(peloton::codegen, RuntimeFunctions, ThrowDivideByZeroException);
3637
DEFINE_METHOD(peloton::codegen, RuntimeFunctions, ThrowOverflowException);
3738

0 commit comments

Comments
 (0)