18
18
#include " codegen/consumer_context.h"
19
19
#include " codegen/lang/loop.h"
20
20
#include " codegen/proxy/executor_context_proxy.h"
21
+ #include " codegen/proxy/runtime_functions_proxy.h"
21
22
#include " settings/settings_manager.h"
22
23
23
24
namespace peloton {
24
25
namespace codegen {
25
26
27
+ namespace {
28
+
29
+ std::string CreateUniqueFunctionName (Pipeline &pipeline,
30
+ const std::string &prefix) {
31
+ CompilationContext &compilation_ctx = pipeline.GetCompilationContext ();
32
+ CodeContext &cc = compilation_ctx.GetCodeGen ().GetCodeContext ();
33
+ return StringUtil::Format (" _%" PRId64 " _pipeline_%u_%s_%s" , cc.GetID (),
34
+ pipeline.GetId (), prefix.c_str (),
35
+ pipeline.ConstructPipelineName ().c_str ());
36
+ }
37
+
38
+ } // namespace
39
+
26
40
// //////////////////////////////////////////////////////////////////////////////
27
41
// /
28
42
// / LoopOverStates
@@ -36,30 +50,83 @@ void PipelineContext::LoopOverStates::Do(
36
50
const std::function<void (llvm::Value *)> &body) const {
37
51
auto &compilation_ctx = ctx_.GetPipeline ().GetCompilationContext ();
38
52
auto &exec_consumer = compilation_ctx.GetExecutionConsumer ();
39
- auto *thread_states = exec_consumer. GetThreadStatesPtr (compilation_ctx );
53
+ auto &codegen = compilation_ctx. GetCodeGen ( );
40
54
41
- CodeGen &codegen = compilation_ctx.GetCodeGen ();
55
+ llvm::Value *thread_states =
56
+ exec_consumer.GetThreadStatesPtr (compilation_ctx);
42
57
43
58
llvm::Value *num_threads =
44
59
codegen.Load (ThreadStatesProxy::num_threads, thread_states);
45
60
llvm::Value *state_size =
46
61
codegen.Load (ThreadStatesProxy::state_size, thread_states);
47
- llvm::Value *states = codegen.Load (ThreadStatesProxy::states, thread_states);
48
62
49
- llvm::Value *state_end = codegen-> CreateInBoundsGEP (
50
- states, { codegen->CreateMul (num_threads, state_size)} );
63
+ llvm::Value *states = codegen. Load (ThreadStatesProxy::states, thread_states);
64
+ states = codegen->CreatePointerCast (states, codegen. CharPtrType () );
51
65
52
- llvm::Value *loop_cond = codegen->CreateICmpNE (states, state_end);
53
- lang::Loop state_loop{codegen, loop_cond, {{" threadState" , states}}};
66
+ llvm::Value *tid = codegen.Const32 (0 );
67
+ llvm::Value *loop_cond = codegen->CreateICmpNE (tid, num_threads);
68
+ lang::Loop state_loop{codegen, loop_cond, {{" tid" , tid}}};
54
69
{
55
- // Pull out state in this iteration
56
- llvm::Value *curr_state = state_loop.GetLoopVar (0 );
70
+ // Pull out state for current TID
71
+ tid = state_loop.GetLoopVar (0 );
72
+ llvm::Value *offset = codegen->CreateMul (tid, state_size);
73
+
74
+ llvm::Value *raw_ptr = codegen->CreateInBoundsGEP (states, {offset});
75
+ llvm::Value *state = codegen->CreatePointerCast (
76
+ raw_ptr, ctx_.GetThreadStateType ()->getPointerTo ());
77
+
57
78
// Invoke caller
58
- body (curr_state);
79
+ body (state);
80
+
59
81
// Wrap up
60
- states = codegen->CreateInBoundsGEP (states, {state_size});
61
- state_loop.LoopEnd (codegen->CreateICmpNE (states, state_end), {states});
82
+ tid = codegen->CreateAdd (tid, codegen.Const32 (1 ));
83
+ state_loop.LoopEnd (codegen->CreateICmpNE (tid, num_threads), {tid});
84
+ }
85
+ }
86
+
87
+ void PipelineContext::LoopOverStates::DoParallel (
88
+ const std::function<void (llvm::Value *)> &body) const {
89
+ Pipeline &pipeline = ctx_.GetPipeline ();
90
+ CompilationContext &comp_ctx = pipeline.GetCompilationContext ();
91
+ QueryState &query_state = comp_ctx.GetQueryState ();
92
+ CodeGen &codegen = comp_ctx.GetCodeGen ();
93
+
94
+ auto name = CreateUniqueFunctionName (pipeline, " loopThreadState" );
95
+
96
+ std::vector<FunctionDeclaration::ArgumentInfo> args = {
97
+ {" queryState" , query_state.GetType ()->getPointerTo ()},
98
+ {" threadState" , ctx_.GetThreadStateType ()->getPointerTo ()}};
99
+ FunctionDeclaration decl{codegen.GetCodeContext (), name,
100
+ FunctionDeclaration::Visibility::Internal,
101
+ codegen.VoidType (), args};
102
+ FunctionBuilder func{codegen.GetCodeContext (), decl};
103
+ {
104
+ // Pull out arguments
105
+ auto *thread_state_ptr = func.GetArgumentByPosition (1 );
106
+
107
+ // Setup access to the thread state
108
+ PipelineContext::ScopedStateAccess state_access{
109
+ ctx_, func.GetArgumentByPosition (1 )};
110
+
111
+ // Execute function body
112
+ body (thread_state_ptr);
113
+
114
+ // Finish
115
+ func.ReturnAndFinish ();
62
116
}
117
+
118
+ // Invoke the per-state dispatch function
119
+
120
+ std::vector<llvm::Value *> dispatch_args = {
121
+ // The (void*) query state
122
+ codegen->CreatePointerCast (codegen.GetState (), codegen.VoidPtrType ()),
123
+ // The (ThreadStates &) thread states
124
+ comp_ctx.GetExecutionConsumer ().GetThreadStatesPtr (comp_ctx),
125
+ // The function
126
+ codegen->CreatePointerCast (
127
+ func.GetFunction (),
128
+ proxy::TypeBuilder<void (*)(void *, void *)>::GetType (codegen))};
129
+ codegen.Call (RuntimeFunctionsProxy::ExecutePerState, dispatch_args);
63
130
}
64
131
65
132
// //////////////////////////////////////////////////////////////////////////////
@@ -150,6 +217,12 @@ uint32_t PipelineContext::GetEntryOffset(CodeGen &codegen,
150
217
return static_cast <uint32_t >(codegen.ElementOffset (state_type, state_id));
151
218
}
152
219
220
+ bool PipelineContext::HasState () const {
221
+ PELOTON_ASSERT (thread_state_type_ != nullptr &&
222
+ " Cannot query state components until it has been finalized" );
223
+ return state_components_.size () > 1 ;
224
+ }
225
+
153
226
/// Forwards to the wrapped pipeline's IsParallel().
bool PipelineContext::IsParallel() const {
  return pipeline_.IsParallel();
}
154
227
155
228
/// Returns a mutable reference to the pipeline this context is attached to.
Pipeline &PipelineContext::GetPipeline() {
  return pipeline_;
}
@@ -283,19 +356,6 @@ uint32_t Pipeline::GetTranslatorStage(
283
356
// /
284
357
// //////////////////////////////////////////////////////////////////////////////
285
358
286
- namespace {
287
-
288
- std::string CreateUniqueFunctionName (Pipeline &pipeline,
289
- const std::string &prefix) {
290
- CompilationContext &compilation_ctx = pipeline.GetCompilationContext ();
291
- CodeContext &cc = compilation_ctx.GetCodeGen ().GetCodeContext ();
292
- return StringUtil::Format (" _%" PRId64 " _pipeline_%u_%s_%s" , cc.GetID (),
293
- pipeline.GetId (), prefix.c_str (),
294
- pipeline.ConstructPipelineName ().c_str ());
295
- }
296
-
297
- } // namespace
298
-
299
359
std::string Pipeline::ConstructPipelineName () const {
300
360
std::vector<std::string> parts;
301
361
for (auto riter = pipeline_.rbegin (), rend = pipeline_.rend (); riter != rend;
@@ -354,8 +414,8 @@ void Pipeline::InitializePipeline(PipelineContext &pipeline_ctx) {
354
414
{" queryState" , query_state.GetType ()->getPointerTo ()},
355
415
{" threadState" , pipeline_ctx.GetThreadStateType ()->getPointerTo ()}};
356
416
357
- FunctionDeclaration init_decl ( cc, func_name, visibility, ret_type, args) ;
358
- FunctionBuilder init_func ( cc, init_decl) ;
417
+ FunctionDeclaration init_decl{ cc, func_name, visibility, ret_type, args} ;
418
+ FunctionBuilder init_func{ cc, init_decl} ;
359
419
{
360
420
PipelineContext::ScopedStateAccess state_access{
361
421
pipeline_ctx, init_func.GetArgumentByPosition (1 )};
@@ -386,7 +446,11 @@ void Pipeline::CompletePipeline(PipelineContext &pipeline_ctx) {
386
446
return ;
387
447
}
388
448
389
- // Loop over all states
449
+ if (!pipeline_ctx.HasState ()) {
450
+ return ;
451
+ }
452
+
453
+ // Loop over all states to allow operators to clean up components
390
454
PipelineContext::LoopOverStates loop_state{pipeline_ctx};
391
455
loop_state.Do ([this , &pipeline_ctx](llvm::Value *thread_state) {
392
456
PipelineContext::ScopedStateAccess state_access{pipeline_ctx, thread_state};
@@ -429,8 +493,7 @@ void Pipeline::Run(
429
493
InitializePipeline (pipeline_ctx);
430
494
431
495
// Generate pipeline
432
- DoRun (pipeline_ctx, dispatch_func, dispatch_args, pipeline_arg_types,
433
- body);
496
+ DoRun (pipeline_ctx, dispatch_func, dispatch_args, pipeline_arg_types, body);
434
497
435
498
// Finish
436
499
CompletePipeline (pipeline_ctx);
@@ -460,8 +523,8 @@ void Pipeline::DoRun(
460
523
}
461
524
462
525
// The main function
463
- FunctionDeclaration declaration ( cc, func_name, visibility, ret_type, args) ;
464
- FunctionBuilder func ( cc, declaration) ;
526
+ FunctionDeclaration declaration{ cc, func_name, visibility, ret_type, args} ;
527
+ FunctionBuilder func{ cc, declaration} ;
465
528
{
466
529
auto *query_state = func.GetArgumentByPosition (0 );
467
530
auto *thread_state = func.GetArgumentByPosition (1 );
0 commit comments