@@ -31,14 +31,11 @@ namespace codegen {
31
31
32
32
PipelineContext::PipelineContext (Pipeline &pipeline)
33
33
: pipeline_(pipeline),
34
+ init_flag_id_ (0 ),
34
35
thread_state_type_(nullptr ),
35
36
thread_state_(nullptr ),
36
37
thread_init_func_(nullptr ),
37
- pipeline_func_(nullptr ) {
38
- // Make room for the bool flag indicating validity
39
- CodeGen &codegen = pipeline.GetCompilationContext ().GetCodeGen ();
40
- state_components_.emplace_back (" initialized" , codegen.BoolType ());
41
- }
38
+ pipeline_func_(nullptr ) {}
42
39
43
40
PipelineContext::Id PipelineContext::RegisterState (std::string name,
44
41
llvm::Type *type) {
@@ -59,6 +56,9 @@ void PipelineContext::FinalizeState(CodeGen &codegen) {
59
56
return ;
60
57
}
61
58
59
+ // Tag on the initialization flag at the end
60
+ init_flag_id_ = RegisterState (" initialized" , codegen.BoolType ());
61
+
62
62
// Pull out types
63
63
std::vector<llvm::Type *> types;
64
64
for (const auto &slot_info : state_components_) {
@@ -77,13 +77,13 @@ llvm::Value *PipelineContext::AccessThreadState(
77
77
}
78
78
79
79
llvm::Value *PipelineContext::LoadFlag (CodeGen &codegen) const {
80
- return LoadState (codegen, kFlagOffset );
80
+ return LoadState (codegen, init_flag_id_ );
81
81
}
82
82
83
83
void PipelineContext::StoreFlag (CodeGen &codegen, llvm::Value *flag) const {
84
84
PL_ASSERT (flag->getType ()->isIntegerTy (1 ) &&
85
85
flag->getType () == codegen.BoolType ());
86
- auto *flag_ptr = LoadStatePtr (codegen, kFlagOffset );
86
+ auto *flag_ptr = LoadStatePtr (codegen, init_flag_id_ );
87
87
codegen->CreateStore (flag, flag_ptr);
88
88
}
89
89
@@ -455,7 +455,7 @@ void Pipeline::DoRun(
455
455
456
456
// If the pipeline is parallel, we need to call the generated init function
457
457
if (IsParallel ()) {
458
- thread_state = codegen->CreateBitOrPointerCast (
458
+ thread_state = codegen->CreatePointerCast (
459
459
thread_state, pipeline_context.GetThreadStateType ()->getPointerTo ());
460
460
461
461
auto *init_func = pipeline_context.thread_init_func_ ;
@@ -485,33 +485,45 @@ void Pipeline::DoRun(
485
485
}
486
486
pipeline_context.pipeline_func_ = func.GetFunction ();
487
487
488
- // The launch argument starts with QueryState and ThreadState. We pass in
489
- // NULL for serial execution pipelines.
490
- std::vector<llvm::Value *> new_dispatch_args = {codegen.GetState ()};
488
+ // The pipeline function we generated above encapsulates the logic for all
489
+ // operators in the pipeline. If we're executing it serially then we directly
490
+ // invoke the function now. If the pipeline is run in parallel then a dispatch
491
+ // function must have been provided. Either way, we need to setup the call
492
+ // now.
493
+ //
494
+ // In both cases, the pipeline function expects QueryState and ThreadState
495
+ // pointers are the first two arguments. When run serially, we pass in a NULL
496
+ // thread state pointer. When running in parallel (through a dispatch
497
+ // function), we need to convert the QueryState type to a void * because it is
498
+ // a runtime generated type (i.e., pre-compiled code doesn't know the layout
499
+ // since it's dynamic)
500
+ //
501
+ // After this, the next arguments are whatever the caller provided to use.
502
+ //
503
+ // Finally, if the pipeline is run through a dispatcher function, the last
504
+ // argument is a function pointer to the pipeline function we generated.
505
+
506
+ std::vector<llvm::Value *> invoke_args = {codegen.GetState ()};
491
507
if (IsParallel ()) {
492
508
auto &consumer = compilation_ctx_.GetExecutionConsumer ();
493
- new_dispatch_args .push_back (consumer.GetThreadStatesPtr (compilation_ctx_));
509
+ invoke_args .push_back (consumer.GetThreadStatesPtr (compilation_ctx_));
494
510
} else {
495
- new_dispatch_args .push_back (codegen.NullPtr (codegen.CharPtrType ()));
511
+ invoke_args .push_back (codegen.NullPtr (codegen.CharPtrType ()));
496
512
}
497
513
498
- // Now insert the arguments the caller wants
499
- new_dispatch_args. insert (new_dispatch_args. end (), dispatch_args.begin (),
500
- dispatch_args. end ());
514
+ invoke_args. insert (invoke_args. end (), dispatch_args. begin (),
515
+ dispatch_args.end ());
516
+
501
517
if (dispatch_func != nullptr ) {
502
- // If we have a launch function, we need to cast the QueryState parameter
503
- // to a void*. This is because it is a JITed data-structure which
504
- // pre-compiled code has no notion of.
505
- //
506
- // We also append the pipeline function to the end of the arguments
507
- new_dispatch_args[0 ] = codegen->CreateBitOrPointerCast (
508
- new_dispatch_args[0 ], codegen.VoidPtrType ());
509
- new_dispatch_args.push_back (
518
+ // Convert QueryState to void *
519
+ invoke_args[0 ] =
520
+ codegen->CreateBitOrPointerCast (invoke_args[0 ], codegen.VoidPtrType ());
521
+ // Tag on the pipeline function
522
+ invoke_args.push_back (
510
523
codegen->CreateBitCast (func.GetFunction (), codegen.VoidPtrType ()));
511
- codegen.CallFunc (dispatch_func, new_dispatch_args );
524
+ codegen.CallFunc (dispatch_func, invoke_args );
512
525
} else {
513
- // Immediately invoke the pipeline function
514
- codegen.CallFunc (func.GetFunction (), new_dispatch_args);
526
+ codegen.CallFunc (func.GetFunction (), invoke_args);
515
527
}
516
528
}
517
529
0 commit comments