@@ -270,6 +270,9 @@ void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept {
270
270
// ComputePipeline
271
271
//
272
272
273
+ ComputePipeline::ComputePipeline (VkDevice device, VkPipeline handle)
274
+ : device_{device}, handle_{handle} {}
275
+
273
276
ComputePipeline::ComputePipeline (
274
277
VkDevice device,
275
278
const ComputePipeline::Descriptor& descriptor,
@@ -444,19 +447,94 @@ ComputePipelineCache::~ComputePipelineCache() {
444
447
pipeline_cache_ = VK_NULL_HANDLE;
445
448
}
446
449
450
+ bool ComputePipelineCache::contains (const ComputePipelineCache::Key& key) {
451
+ std::lock_guard<std::mutex> lock (cache_mutex_);
452
+
453
+ auto it = cache_.find (key);
454
+ return it != cache_.cend ();
455
+ }
456
+
457
+ void ComputePipelineCache::create_pipelines (
458
+ const std::unordered_set<Key, Hasher>& descriptors) {
459
+ std::lock_guard<std::mutex> lock (cache_mutex_);
460
+
461
+ const auto num_pipelines = descriptors.size ();
462
+ std::vector<VkPipeline> pipelines (num_pipelines);
463
+
464
+ std::vector<std::vector<VkSpecializationMapEntry>> map_entries;
465
+ map_entries.reserve (num_pipelines);
466
+
467
+ std::vector<VkSpecializationInfo> specialization_infos;
468
+ specialization_infos.reserve (num_pipelines);
469
+
470
+ std::vector<VkPipelineShaderStageCreateInfo> shader_stage_create_infos;
471
+ shader_stage_create_infos.reserve (num_pipelines);
472
+
473
+ std::vector<VkComputePipelineCreateInfo> create_infos;
474
+ create_infos.reserve (num_pipelines);
475
+
476
+ for (auto & key : descriptors) {
477
+ map_entries.push_back (key.specialization_constants .generate_map_entries ());
478
+
479
+ specialization_infos.push_back (VkSpecializationInfo{
480
+ key.specialization_constants .size (), // mapEntryCount
481
+ map_entries.back ().data (), // pMapEntries
482
+ key.specialization_constants .data_nbytes (), // dataSize
483
+ key.specialization_constants .data (), // pData
484
+ });
485
+
486
+ shader_stage_create_infos.push_back (VkPipelineShaderStageCreateInfo{
487
+ VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
488
+ nullptr , // pNext
489
+ 0u , // flags
490
+ VK_SHADER_STAGE_COMPUTE_BIT, // stage
491
+ key.shader_module , // module
492
+ " main" , // pName
493
+ &specialization_infos.back (), // pSpecializationInfo
494
+ });
495
+
496
+ create_infos.push_back (VkComputePipelineCreateInfo{
497
+ VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
498
+ nullptr , // pNext
499
+ 0u , // flags
500
+ shader_stage_create_infos.back (), // stage
501
+ key.pipeline_layout , // layout
502
+ VK_NULL_HANDLE, // basePipelineHandle
503
+ 0u , // basePipelineIndex
504
+ });
505
+ }
506
+
507
+ VK_CHECK (vkCreateComputePipelines (
508
+ device_,
509
+ pipeline_cache_,
510
+ create_infos.size (),
511
+ create_infos.data (),
512
+ nullptr ,
513
+ pipelines.data ()));
514
+
515
+ uint32_t i = 0 ;
516
+ for (auto & key : descriptors) {
517
+ auto it = cache_.find (key);
518
+ if (it != cache_.cend ()) {
519
+ continue ;
520
+ }
521
+ cache_.insert ({key, ComputePipelineCache::Value (device_, pipelines[i])});
522
+ ++i;
523
+ }
524
+ }
525
+
447
526
VkPipeline ComputePipelineCache::retrieve (
448
527
const ComputePipelineCache::Key& key) {
449
528
std::lock_guard<std::mutex> lock (cache_mutex_);
450
529
451
530
auto it = cache_.find (key);
452
- if (cache_.cend () == it ) {
531
+ if (it == cache_.cend ()) {
453
532
it = cache_
454
533
.insert (
455
534
{key,
456
535
ComputePipelineCache::Value (device_, key, pipeline_cache_)})
457
536
.first ;
458
537
}
459
-
460
538
return it->second .handle ();
461
539
}
462
540
0 commit comments