Skip to content

Commit 02dac82

Browse files
Resolve code review comments
Signed-off-by: Davide Grohmann <[email protected]> Change-Id: Ib58ef4d1d24e395678c9527abdd7e96a9b1df9eb
1 parent b5156d6 commit 02dac82

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,16 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
9696
}
9797

9898
static SetVector<spirv::Capability>
99-
withImpliedCapabilities(SetVector<spirv::Capability> &caps) {
100-
SetVector<spirv::Capability> allCaps(caps.begin(), caps.end());
101-
for (auto cap : caps) {
102-
ArrayRef<spirv::Capability> directCaps = getDirectImpliedCapabilities(cap);
103-
allCaps.insert(directCaps.begin(), directCaps.end());
99+
addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
100+
SetVector<spirv::Capability> allCaps;
101+
while (!caps.empty()) {
102+
spirv::Capability cap = caps.pop_back_val();
103+
allCaps.insert(cap);
104+
ArrayRef<spirv::Capability> impliedCaps = getDirectImpliedCapabilities(cap);
105+
for (spirv::Capability impliedCap : impliedCaps) {
106+
if (!allCaps.contains(impliedCap))
107+
caps.insert(impliedCap);
108+
}
104109
}
105110
return allCaps;
106111
}
@@ -178,7 +183,7 @@ void UpdateVCEPass::runOnOperation() {
178183
return WalkResult::interrupt();
179184
}
180185

181-
deducedCapabilities = withImpliedCapabilities(deducedCapabilities);
186+
deducedCapabilities = addAllImpliedCapabilities(deducedCapabilities);
182187

183188
return WalkResult::advance();
184189
});

mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes {
2121
// Test deducing minimal version.
2222
// spirv.GroupNonUniformBallot is available since v1.3.
2323

24-
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
24+
// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniform, GroupNonUniformBallot], []>
2525
spirv.module Logical GLSL450 attributes {
2626
spirv.target_env = #spirv.target_env<
2727
#spirv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -61,7 +61,7 @@ spirv.module Logical GLSL450 attributes {
6161

6262
// Test Physical Storage Buffers are deduced correctly.
6363

64-
// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader, Matrix], [SPV_EXT_physical_storage_buffer]>
64+
// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [Shader, Matrix, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>
6565
spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
6666
spirv.target_env = #spirv.target_env<
6767
#spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
@@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes {
9595
// * GroupNonUniformArithmetic
9696
// * GroupNonUniformBallot
9797

98-
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, GroupNonUniform, Shader, Matrix], []>
98+
// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniform, GroupNonUniformArithmetic], []>
9999
spirv.module Logical GLSL450 attributes {
100100
spirv.target_env = #spirv.target_env<
101101
#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
@@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes {
106106
}
107107
}
108108

109-
// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
109+
// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformClustered, GroupNonUniform, GroupNonUniformBallot], []>
110110
spirv.module Logical GLSL450 attributes {
111111
spirv.target_env = #spirv.target_env<
112112
#spirv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes {
120120
// Test type required capabilities
121121

122122
// Using 8-bit integers in non-interface storage class requires Int8.
123-
// CHECK: requires #spirv.vce<v1.0, [Int8, Shader, Matrix], []>
123+
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Int8], []>
124124
spirv.module Logical GLSL450 attributes {
125125
spirv.target_env = #spirv.target_env<
126126
#spirv.vce<v1.3, [Shader, Int8], []>, #spirv.resource_limits<>>
@@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes {
132132
}
133133

134134
// Using 16-bit floats in non-interface storage class requires Float16.
135-
// CHECK: requires #spirv.vce<v1.0, [Float16, Shader, Matrix], []>
135+
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Float16], []>
136136
spirv.module Logical GLSL450 attributes {
137137
spirv.target_env = #spirv.target_env<
138138
#spirv.vce<v1.3, [Shader, Float16], []>, #spirv.resource_limits<>>
@@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes {
144144
}
145145

146146
// Using 16-element vectors requires Vector16.
147-
// CHECK: requires #spirv.vce<v1.0, [Vector16, Kernel, Shader, Matrix], []>
147+
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Kernel, Vector16], []>
148148
spirv.module Logical GLSL450 attributes {
149149
spirv.target_env = #spirv.target_env<
150150
#spirv.vce<v1.3, [Shader, Vector16], []>, #spirv.resource_limits<>>
@@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes {
162162
// Test deducing minimal extensions.
163163
// spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension.
164164

165-
// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader, Matrix], [SPV_KHR_shader_ballot]>
165+
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, SubgroupBallotKHR], [SPV_KHR_shader_ballot]>
166166
spirv.module Logical GLSL450 attributes {
167167
spirv.target_env = #spirv.target_env<
168168
#spirv.vce<v1.0, [Shader, SubgroupBallotKHR],
@@ -193,7 +193,7 @@ spirv.module Logical Vulkan attributes {
193193

194194
// Using 8-bit integers in interface storage class requires additional
195195
// extensions and capabilities.
196-
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16, Matrix], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
196+
// CHECK: requires #spirv.vce<v1.0, [Int16, Shader, Matrix, StorageBuffer16BitAccess], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
197197
spirv.module Logical GLSL450 attributes {
198198
spirv.target_env = #spirv.target_env<
199199
#spirv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>, #spirv.resource_limits<>>
@@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes {
208208
// Complicated nested types
209209
// * Buffer requires ImageBuffer or SampledBuffer.
210210
// * Rg32f requires StorageImageExtendedFormats.
211-
// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, StorageBuffer8BitAccess, StorageBuffer16BitAccess, Matrix, ImageBuffer, StorageImageExtendedFormats, SampledBuffer], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
211+
// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Int64, StorageUniform16, StorageBuffer16BitAccess, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, SampledBuffer, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
212212
spirv.module Logical GLSL450 attributes {
213213
spirv.target_env = #spirv.target_env<
214214
#spirv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
@@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes {
219219
}
220220

221221
// Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
222-
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR, Matrix], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
222+
// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Matrix, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
223223
spirv.module Logical GLSL450 attributes {
224224
spirv.target_env = #spirv.target_env<
225225
#spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,

0 commit comments

Comments
 (0)