Skip to content

Commit 718834a

Browse files
committed
Refactor compute shader and graphics pipeline logic
Improved device extension handling and shader input/output structures. Enhanced readability with structured initialization. Updated device selection to validate Vulkan 1.3 features and required extensions. Refined pipeline blending and color attachment setup for clarity and correctness.
1 parent 3121659 commit 718834a

File tree

2 files changed

+91
-74
lines changed

2 files changed

+91
-74
lines changed

attachments/31_compute_shader.cpp

Lines changed: 78 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class ComputeShaderApplication {
126126

127127
double lastTime = 0.0f;
128128

129-
std::vector<const char*> deviceExtensions = {
129+
std::vector<const char*> requiredDeviceExtension = {
130130
vk::KHRSwapchainExtensionName,
131131
vk::KHRSpirv14ExtensionName,
132132
vk::KHRSynchronization2ExtensionName,
@@ -256,30 +256,41 @@ class ComputeShaderApplication {
256256

257257
void pickPhysicalDevice() {
258258
std::vector<vk::raii::PhysicalDevice> devices = instance.enumeratePhysicalDevices();
259-
const auto devIter = std::ranges::find_if(devices,
260-
[&](auto const & device) {
259+
const auto devIter = std::ranges::find_if(
260+
devices,
261+
[&]( auto const & device )
262+
{
263+
// Check if the device supports the Vulkan 1.3 API version
264+
bool supportsVulkan1_3 = device.getProperties().apiVersion >= VK_API_VERSION_1_3;
265+
266+
// Check if any of the queue families support graphics operations
261267
auto queueFamilies = device.getQueueFamilyProperties();
262-
bool isSuitable = device.getProperties().apiVersion >= VK_API_VERSION_1_3;
263-
const auto qfpIter = std::ranges::find_if(queueFamilies,
264-
[]( vk::QueueFamilyProperties const & qfp )
268+
bool supportsGraphics =
269+
std::ranges::any_of( queueFamilies, []( auto const & qfp ) { return !!( qfp.queueFlags & vk::QueueFlagBits::eGraphics ); } );
270+
271+
// Check if all required device extensions are available
272+
auto availableDeviceExtensions = device.enumerateDeviceExtensionProperties();
273+
bool supportsAllRequiredExtensions =
274+
std::ranges::all_of( requiredDeviceExtension,
275+
[&availableDeviceExtensions]( auto const & requiredDeviceExtension )
265276
{
266-
return (qfp.queueFlags & vk::QueueFlagBits::eGraphics) != static_cast<vk::QueueFlags>(0);
277+
return std::ranges::any_of( availableDeviceExtensions,
278+
[requiredDeviceExtension]( auto const & availableDeviceExtension )
279+
{ return strcmp( availableDeviceExtension.extensionName, requiredDeviceExtension ) == 0; } );
267280
} );
268-
isSuitable = isSuitable && ( qfpIter != queueFamilies.end() );
269-
auto extensions = device.enumerateDeviceExtensionProperties( );
270-
bool found = true;
271-
for (auto const & extension : deviceExtensions) {
272-
auto extensionIter = std::ranges::find_if(extensions, [extension](auto const & ext) {return strcmp(ext.extensionName, extension) == 0;});
273-
found = found && extensionIter != extensions.end();
274-
}
275-
isSuitable = isSuitable && found;
276-
printf("\n");
277-
if (isSuitable) {
278-
physicalDevice = device;
279-
}
280-
return isSuitable;
281+
282+
auto features = device.template getFeatures2<vk::PhysicalDeviceFeatures2, vk::PhysicalDeviceVulkan13Features, vk::PhysicalDeviceExtendedDynamicStateFeaturesEXT>();
283+
bool supportsRequiredFeatures = features.template get<vk::PhysicalDeviceVulkan13Features>().dynamicRendering &&
284+
features.template get<vk::PhysicalDeviceExtendedDynamicStateFeaturesEXT>().extendedDynamicState;
285+
286+
return supportsVulkan1_3 && supportsGraphics && supportsAllRequiredExtensions && supportsRequiredFeatures;
281287
});
282-
if (devIter == devices.end()) {
288+
if ( devIter != devices.end() )
289+
{
290+
physicalDevice = *devIter;
291+
}
292+
else
293+
{
283294
throw std::runtime_error("failed to find a suitable GPU!");
284295
}
285296
}
@@ -289,18 +300,16 @@ class ComputeShaderApplication {
289300
std::vector<vk::QueueFamilyProperties> queueFamilyProperties = physicalDevice.getQueueFamilyProperties();
290301

291302
// get the first index into queueFamilyProperties which supports graphics and compute
292-
auto graphicsAndComputeQueueFamilyProperty =
293-
std::find_if( queueFamilyProperties.begin(),
294-
queueFamilyProperties.end(),
295-
[]( vk::QueueFamilyProperties const & qfp ) { return (qfp.queueFlags & vk::QueueFlagBits::eGraphics && qfp.queueFlags & vk::QueueFlagBits::eCompute); } );
303+
auto graphicsAndComputeQueueFamilyProperty = std::ranges::find_if(queueFamilyProperties, []( auto const & qfp )
304+
{ return (qfp.queueFlags & vk::QueueFlagBits::eGraphics && qfp.queueFlags & vk::QueueFlagBits::eCompute); } );
296305

297306
graphicsAndComputeIndex = static_cast<uint32_t>( std::distance( queueFamilyProperties.begin(), graphicsAndComputeQueueFamilyProperty ) );
298307

299308
// determine a queueFamilyIndex that supports present
300309
// first check if the graphicsIndex is good enough
301310
auto presentIndex = physicalDevice.getSurfaceSupportKHR( graphicsAndComputeIndex, surface )
302311
? graphicsAndComputeIndex
303-
: static_cast<uint32_t>( queueFamilyProperties.size() );
312+
: ~0;
304313
if ( presentIndex == queueFamilyProperties.size() )
305314
{
306315
// the graphicsIndex doesn't support present -> look for another family index that supports both
@@ -350,9 +359,13 @@ class ComputeShaderApplication {
350359
// create a Device
351360
float queuePriority = 0.0f;
352361
vk::DeviceQueueCreateInfo deviceQueueCreateInfo{ .queueFamilyIndex = graphicsAndComputeIndex, .queueCount = 1, .pQueuePriorities = &queuePriority };
353-
vk::DeviceCreateInfo deviceCreateInfo{ .pNext = &features, .queueCreateInfoCount = 1, .pQueueCreateInfos = &deviceQueueCreateInfo };
354-
deviceCreateInfo.enabledExtensionCount = deviceExtensions.size();
355-
deviceCreateInfo.ppEnabledExtensionNames = deviceExtensions.data();
362+
vk::DeviceCreateInfo deviceCreateInfo{
363+
.pNext = &features,
364+
.queueCreateInfoCount = 1,
365+
.pQueueCreateInfos = &deviceQueueCreateInfo,
366+
.enabledExtensionCount = static_cast<uint32_t>(requiredDeviceExtension.size()),
367+
.ppEnabledExtensionNames = requiredDeviceExtension.data()
368+
};
356369

357370
device = vk::raii::Device( physicalDevice, deviceCreateInfo );
358371
graphicsQueue = vk::raii::Queue( device, graphicsAndComputeIndex, 0 );
@@ -407,8 +420,10 @@ class ComputeShaderApplication {
407420

408421

409422
void createGraphicsPipeline() {
423+
vk::raii::ShaderModule fragShaderModule = createShaderModule(readFile("shaders/frag.spv"));
410424
vk::raii::ShaderModule shaderModule = createShaderModule(readFile("shaders/slang.spv"));
411425

426+
// vk::PipelineShaderStageCreateInfo fragShaderStageInfo{ .stage = vk::ShaderStageFlagBits::eFragment, .module = fragShaderModule, .pName ="main"};
412427
vk::PipelineShaderStageCreateInfo vertShaderStageInfo{ .stage = vk::ShaderStageFlagBits::eVertex, .module = shaderModule, .pName = "vertMain" };
413428
vk::PipelineShaderStageCreateInfo fragShaderStageInfo{ .stage = vk::ShaderStageFlagBits::eFragment, .module = shaderModule, .pName = "fragMain" };
414429
vk::PipelineShaderStageCreateInfo shaderStages[] = {vertShaderStageInfo, fragShaderStageInfo};
@@ -419,57 +434,53 @@ class ComputeShaderApplication {
419434
vk::PipelineVertexInputStateCreateInfo vertexInputInfo{ .vertexBindingDescriptionCount = 1, .pVertexBindingDescriptions = &bindingDescription, .vertexAttributeDescriptionCount = static_cast<uint32_t>(attributeDescriptions.size()), .pVertexAttributeDescriptions = attributeDescriptions.data() };
420435
vk::PipelineInputAssemblyStateCreateInfo inputAssembly{ .topology = vk::PrimitiveTopology::ePointList, .primitiveRestartEnable = vk::False };
421436
vk::PipelineViewportStateCreateInfo viewportState{ .viewportCount = 1, .scissorCount = 1 };
422-
vk::PipelineRasterizationStateCreateInfo rasterizer{};
423-
rasterizer.depthClampEnable = vk::False;
424-
rasterizer.rasterizerDiscardEnable = vk::False;
425-
rasterizer.polygonMode = vk::PolygonMode::eFill;
426-
rasterizer.cullMode = vk::CullModeFlagBits::eBack;
427-
rasterizer.frontFace = vk::FrontFace::eCounterClockwise;
428-
rasterizer.depthBiasEnable = vk::False;
429-
rasterizer.lineWidth = 1.0f;
437+
vk::PipelineRasterizationStateCreateInfo rasterizer{
438+
.depthClampEnable = vk::False,
439+
.rasterizerDiscardEnable = vk::False,
440+
.polygonMode = vk::PolygonMode::eFill,
441+
.cullMode = vk::CullModeFlagBits::eBack,
442+
.frontFace = vk::FrontFace::eCounterClockwise,
443+
.depthBiasEnable = vk::False,
444+
.lineWidth = 1.0f
445+
};
430446
vk::PipelineMultisampleStateCreateInfo multisampling{ .rasterizationSamples = vk::SampleCountFlagBits::e1, .sampleShadingEnable = vk::False };
431447

432-
vk::PipelineColorBlendAttachmentState colorBlendAttachment;
433-
colorBlendAttachment.colorWriteMask = vk::ColorComponentFlagBits::eR | vk::ColorComponentFlagBits::eG | vk::ColorComponentFlagBits::eB | vk::ColorComponentFlagBits::eA;
434-
colorBlendAttachment.blendEnable = vk::True;
435-
colorBlendAttachment.colorBlendOp = vk::BlendOp::eAdd;
436-
colorBlendAttachment.srcColorBlendFactor = vk::BlendFactor::eSrcAlpha;
437-
colorBlendAttachment.dstColorBlendFactor = vk::BlendFactor::eOneMinusSrcAlpha;
438-
colorBlendAttachment.alphaBlendOp = vk::BlendOp::eAdd;
439-
colorBlendAttachment.srcAlphaBlendFactor = vk::BlendFactor::eOneMinusSrcAlpha;
440-
colorBlendAttachment.dstAlphaBlendFactor = vk::BlendFactor::eZero;
448+
vk::PipelineColorBlendAttachmentState colorBlendAttachment{
449+
.blendEnable = vk::True,
450+
.srcColorBlendFactor = vk::BlendFactor::eSrcAlpha,
451+
.dstColorBlendFactor = vk::BlendFactor::eOneMinusSrcAlpha,
452+
.colorBlendOp = vk::BlendOp::eAdd,
453+
.srcAlphaBlendFactor = vk::BlendFactor::eOneMinusSrcAlpha,
454+
.dstAlphaBlendFactor = vk::BlendFactor::eZero,
455+
.alphaBlendOp = vk::BlendOp::eAdd,
456+
.colorWriteMask = vk::ColorComponentFlagBits::eR | vk::ColorComponentFlagBits::eG | vk::ColorComponentFlagBits::eB | vk::ColorComponentFlagBits::eA,
457+
};
441458

442459
vk::PipelineColorBlendStateCreateInfo colorBlending{ .logicOpEnable = vk::False, .logicOp = vk::LogicOp::eCopy, .attachmentCount = 1, .pAttachments = &colorBlendAttachment };
443-
colorBlending.blendConstants[0] = 0.0f;
444-
colorBlending.blendConstants[1] = 0.0f;
445-
colorBlending.blendConstants[2] = 0.0f;
446-
colorBlending.blendConstants[3] = 0.0f;
447460

448461
std::vector dynamicStates = {
449462
vk::DynamicState::eViewport,
450463
vk::DynamicState::eScissor
451464
};
452465
vk::PipelineDynamicStateCreateInfo dynamicState{ .dynamicStateCount = static_cast<uint32_t>(dynamicStates.size()), .pDynamicStates = dynamicStates.data() };
453466

454-
vk::PipelineLayoutCreateInfo pipelineLayoutInfo{};
455-
pipelineLayoutInfo.sType = vk::StructureType::ePipelineLayoutCreateInfo;
456-
pipelineLayoutInfo.setLayoutCount = 0;
457-
pipelineLayoutInfo.pushConstantRangeCount = 0;
467+
vk::PipelineLayoutCreateInfo pipelineLayoutInfo;
458468
pipelineLayout = vk::raii::PipelineLayout( device, pipelineLayoutInfo );
459469

460470
vk::PipelineRenderingCreateInfo pipelineRenderingCreateInfo{ .colorAttachmentCount = 1, .pColorAttachmentFormats = &swapChainImageFormat.format };
461-
vk::GraphicsPipelineCreateInfo pipelineInfo{.pNext = &pipelineRenderingCreateInfo};
462-
pipelineInfo.stageCount = 2;
463-
pipelineInfo.pStages = shaderStages;
464-
pipelineInfo.pVertexInputState = &vertexInputInfo;
465-
pipelineInfo.pInputAssemblyState = &inputAssembly;
466-
pipelineInfo.pViewportState = &viewportState;
467-
pipelineInfo.pRasterizationState = &rasterizer;
468-
pipelineInfo.pMultisampleState = &multisampling;
469-
pipelineInfo.pColorBlendState = &colorBlending;
470-
pipelineInfo.pDynamicState = &dynamicState;
471-
pipelineInfo.layout = *pipelineLayout;
472-
pipelineInfo.subpass = 0;
471+
vk::GraphicsPipelineCreateInfo pipelineInfo{ .pNext = &pipelineRenderingCreateInfo,
472+
.stageCount = 2,
473+
.pStages = shaderStages,
474+
.pVertexInputState = &vertexInputInfo,
475+
.pInputAssemblyState = &inputAssembly,
476+
.pViewportState = &viewportState,
477+
.pRasterizationState = &rasterizer,
478+
.pMultisampleState = &multisampling,
479+
.pColorBlendState = &colorBlending,
480+
.pDynamicState = &dynamicState,
481+
.layout = *pipelineLayout,
482+
.subpass = 0
483+
};
473484

474485
graphicsPipeline = vk::raii::Pipeline(device, nullptr, pipelineInfo);
475486
}
@@ -478,10 +489,7 @@ class ComputeShaderApplication {
478489
vk::raii::ShaderModule shaderModule = createShaderModule(readFile("shaders/slang.spv"));
479490

480491
vk::PipelineShaderStageCreateInfo computeShaderStageInfo{ .stage = vk::ShaderStageFlagBits::eCompute, .module = shaderModule, .pName = "compMain" };
481-
vk::PipelineLayoutCreateInfo pipelineLayoutInfo{};
482-
pipelineLayoutInfo.sType = vk::StructureType::ePipelineLayoutCreateInfo;
483-
pipelineLayoutInfo.setLayoutCount = 1;
484-
pipelineLayoutInfo.pSetLayouts = &*computeDescriptorSetLayout;
492+
vk::PipelineLayoutCreateInfo pipelineLayoutInfo{ .setLayoutCount = 1, .pSetLayouts = &*computeDescriptorSetLayout };
485493
computePipelineLayout = vk::raii::PipelineLayout( device, pipelineLayoutInfo );
486494
vk::ComputePipelineCreateInfo pipelineInfo{ .stage = computeShaderStageInfo, .layout = *computePipelineLayout };
487495
computePipeline = vk::raii::Pipeline(device, nullptr, pipelineInfo);

attachments/31_shader_compute.slang

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,30 @@ struct VSOutput
88
{
99
float4 pos : SV_Position;
1010
float pointSize : SV_PointSize;
11-
float3 fragColor;
11+
float3 fragColor : COLOR0;
12+
};
13+
14+
struct PSInput
15+
{
16+
float4 pos : SV_POSITION;
17+
float3 fragColor : COLOR0;
18+
float2 pointCoord : SV_PointCoord;
1219
};
1320

1421
[shader("vertex")]
1522
VSOutput vertMain(VSInput input) {
1623
VSOutput output;
1724
output.pointSize = 14.0;
18-
output.pos = float4(input.inPosition.xy, 1.0, 1.0);
25+
output.pos = float4(input.inPosition, 1.0, 1.0);
1926
output.fragColor = input.inColor.rgb;
27+
//output.pointCoord = (input.inPosition * float2(1.0, -1.0) + float2(1.0)) / 2.0;
2028
return output;
2129
}
2230

2331
[shader("fragment")]
24-
float4 fragMain(VSOutput vertIn) : SV_TARGET {
25-
return float4(vertIn.fragColor, 1.0);
32+
float4 fragMain(PSInput input) : SV_TARGET {
33+
float2 coord = input.pointCoord - float2(0.5);
34+
return float4(input.fragColor, 0.5 - length(coord));
2635
}
2736

2837
struct Particle {

0 commit comments

Comments
 (0)