Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
return success();
}

static void addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
for (spirv::Capability cap : caps) {
ArrayRef<spirv::Capability> impliedCaps = getDirectImpliedCapabilities(cap);
caps.insert_range(impliedCaps);
}
}

void UpdateVCEPass::runOnOperation() {
spirv::ModuleOp module = getOperation();

Expand Down Expand Up @@ -168,6 +175,8 @@ void UpdateVCEPass::runOnOperation() {
return WalkResult::interrupt();
}

addAllImpliedCapabilities(deducedCapabilities);

return WalkResult::advance();
});

Expand Down
32 changes: 16 additions & 16 deletions mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// Test deducing minimal version.
// spirv.IAdd is available from v1.0.

// CHECK: requires #spirv.vce<v1.0, [Shader], []>
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader], []>, #spirv.resource_limits<>>
Expand All @@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes {
// Test deducing minimal version.
// spirv.GroupNonUniformBallot is available since v1.3.

// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, Shader], []>
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
Expand All @@ -32,7 +32,7 @@ spirv.module Logical GLSL450 attributes {
}
}

// CHECK: requires #spirv.vce<v1.4, [Shader], []>
// CHECK: requires #spirv.vce<v1.4, [Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader], []>, #spirv.resource_limits<>>
} {
Expand All @@ -48,7 +48,7 @@ spirv.module Logical GLSL450 attributes {

// Test minimal capabilities.

// CHECK: requires #spirv.vce<v1.0, [Shader], []>
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, Float16, Float64, Int16, Int64, VariablePointers], []>, #spirv.resource_limits<>>
Expand All @@ -61,10 +61,10 @@ spirv.module Logical GLSL450 attributes {

// Test Physical Storage Buffers are deduced correctly.

// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader], [SPV_EXT_physical_storage_buffer]>
// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader, Matrix], [SPV_EXT_physical_storage_buffer]>
spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
#spirv.vce<v1.0, [PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
} {
spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
spirv.Return
Expand All @@ -74,7 +74,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
// Test deducing implied capability.
// AtomicStorage implies Shader.

// CHECK: requires #spirv.vce<v1.0, [Shader], []>
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [AtomicStorage], []>, #spirv.resource_limits<>>
Expand All @@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes {
// * GroupNonUniformArithmetic
// * GroupNonUniformBallot

// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, Shader], []>
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, GroupNonUniform, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
Expand All @@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes {
}
}

// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, Shader], []>
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
Expand All @@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes {
// Test type required capabilities

// Using 8-bit integers in non-interface storage class requires Int8.
// CHECK: requires #spirv.vce<v1.0, [Int8, Shader], []>
// CHECK: requires #spirv.vce<v1.0, [Int8, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Int8], []>, #spirv.resource_limits<>>
Expand All @@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes {
}

// Using 16-bit floats in non-interface storage class requires Float16.
// CHECK: requires #spirv.vce<v1.0, [Float16, Shader], []>
// CHECK: requires #spirv.vce<v1.0, [Float16, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Float16], []>, #spirv.resource_limits<>>
Expand All @@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes {
}

// Using 16-element vectors requires Vector16.
// CHECK: requires #spirv.vce<v1.0, [Vector16, Shader], []>
// CHECK: requires #spirv.vce<v1.0, [Vector16, Kernel, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Vector16], []>, #spirv.resource_limits<>>
Expand All @@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes {
// Test deducing minimal extensions.
// spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension.

// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader], [SPV_KHR_shader_ballot]>
// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader, Matrix], [SPV_KHR_shader_ballot]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, SubgroupBallotKHR],
Expand Down Expand Up @@ -193,7 +193,7 @@ spirv.module Logical Vulkan attributes {

// Using 8-bit integers in interface storage class requires additional
// extensions and capabilities.
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16, Matrix], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>, #spirv.resource_limits<>>
Expand All @@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes {
// Complicated nested types
// * Buffer requires ImageBuffer or SampledBuffer.
// * Rg32f requires StorageImageExtendedFormats.
// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, StorageBuffer8BitAccess, StorageBuffer16BitAccess, Matrix, ImageBuffer, StorageImageExtendedFormats, SampledBuffer], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
Expand All @@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes {
}

// Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR, Matrix], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,
Expand Down
Loading