Skip to content

Commit eb500be

Browse files
spirv-val: Validate PhysicalStorageBuffer Align are large enough (KhronosGroup#6266)
* spirv-val: Validate PhysicalStorageBuffer Align are large enough * spirv-val: Move pointer to GetBitWidth * spirv-val: Add UntypedPointer test * spirv-val: Add Untyped PSB test
1 parent 8a8bb6c commit eb500be

File tree

4 files changed

+510
-30
lines changed

4 files changed

+510
-30
lines changed

source/val/validate_memory.cpp

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -196,19 +196,18 @@ bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
196196
return false;
197197
}
198198

199-
std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
200-
ValidationState_t& _, const Instruction* inst) {
201-
spv::StorageClass dst_sc = spv::StorageClass::Max;
202-
spv::StorageClass src_sc = spv::StorageClass::Max;
199+
std::pair<Instruction*, Instruction*> GetPointerTypes(ValidationState_t& _,
200+
const Instruction* inst) {
201+
Instruction* dst_pointer_type = nullptr;
202+
Instruction* src_pointer_type = nullptr;
203203
switch (inst->opcode()) {
204204
case spv::Op::OpCooperativeMatrixLoadNV:
205205
case spv::Op::OpCooperativeMatrixLoadTensorNV:
206206
case spv::Op::OpCooperativeMatrixLoadKHR:
207207
case spv::Op::OpCooperativeVectorLoadNV:
208208
case spv::Op::OpLoad: {
209209
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
210-
auto load_pointer_type = _.FindDef(load_pointer->type_id());
211-
dst_sc = load_pointer_type->GetOperandAs<spv::StorageClass>(1);
210+
dst_pointer_type = _.FindDef(load_pointer->type_id());
212211
break;
213212
}
214213
case spv::Op::OpCooperativeMatrixStoreNV:
@@ -217,25 +216,23 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
217216
case spv::Op::OpCooperativeVectorStoreNV:
218217
case spv::Op::OpStore: {
219218
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
220-
auto store_pointer_type = _.FindDef(store_pointer->type_id());
221-
dst_sc = store_pointer_type->GetOperandAs<spv::StorageClass>(1);
219+
dst_pointer_type = _.FindDef(store_pointer->type_id());
222220
break;
223221
}
222+
// Spec: "Matching Storage Class is not required"
224223
case spv::Op::OpCopyMemory:
225224
case spv::Op::OpCopyMemorySized: {
226-
auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
227-
auto dst_type = _.FindDef(dst->type_id());
228-
dst_sc = dst_type->GetOperandAs<spv::StorageClass>(1);
229-
auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
230-
auto src_type = _.FindDef(src->type_id());
231-
src_sc = src_type->GetOperandAs<spv::StorageClass>(1);
225+
auto dst_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
226+
dst_pointer_type = _.FindDef(dst_pointer->type_id());
227+
auto src_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(1));
228+
src_pointer_type = _.FindDef(src_pointer->type_id());
232229
break;
233230
}
234231
default:
235232
break;
236233
}
237234

238-
return std::make_pair(dst_sc, src_sc);
235+
return std::make_pair(dst_pointer_type, src_pointer_type);
239236
}
240237

241238
// Returns the number of instruction words taken up by a memory access
@@ -288,8 +285,17 @@ bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
288285

289286
spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
290287
uint32_t index) {
291-
spv::StorageClass dst_sc, src_sc;
292-
std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
288+
Instruction* dst_pointer_type = nullptr;
289+
Instruction* src_pointer_type = nullptr; // only used for OpCopyMemory
290+
std::tie(dst_pointer_type, src_pointer_type) = GetPointerTypes(_, inst);
291+
292+
const spv::StorageClass dst_sc =
293+
dst_pointer_type ? dst_pointer_type->GetOperandAs<spv::StorageClass>(1)
294+
: spv::StorageClass::Max;
295+
const spv::StorageClass src_sc =
296+
src_pointer_type ? src_pointer_type->GetOperandAs<spv::StorageClass>(1)
297+
: spv::StorageClass::Max;
298+
293299
if (inst->operands().size() <= index) {
294300
// Cases where lack of some operand is invalid
295301
if (src_sc == spv::StorageClass::PhysicalStorageBuffer ||
@@ -390,6 +396,23 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
390396
<< "Memory accesses Aligned operand value " << aligned_value
391397
<< " is not a power of two.";
392398
}
399+
400+
uint32_t largest_scalar = 0;
401+
if (dst_sc == spv::StorageClass::PhysicalStorageBuffer) {
402+
largest_scalar =
403+
_.GetLargestScalarType(dst_pointer_type->GetOperandAs<uint32_t>(2));
404+
}
405+
if (src_sc == spv::StorageClass::PhysicalStorageBuffer) {
406+
largest_scalar = std::max(
407+
largest_scalar,
408+
_.GetLargestScalarType(src_pointer_type->GetOperandAs<uint32_t>(2)));
409+
}
410+
if (aligned_value < largest_scalar) {
411+
return _.diag(SPV_ERROR_INVALID_ID, inst)
412+
<< _.VkErrorID(6314) << "Memory accesses Aligned operand value "
413+
<< aligned_value << " is too small, the largest scalar type is "
414+
<< largest_scalar << " bytes.";
415+
}
393416
}
394417

395418
return SPV_SUCCESS;

source/val/validation_state.cpp

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,8 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
903903
case spv::Op::OpTypeFloat:
904904
case spv::Op::OpTypeInt:
905905
case spv::Op::OpTypeBool:
906+
case spv::Op::OpTypePointer:
907+
case spv::Op::OpTypeUntypedPointerKHR:
906908
return id;
907909

908910
case spv::Op::OpTypeArray:
@@ -968,11 +970,20 @@ uint32_t ValidationState_t::GetBitWidth(uint32_t id) const {
968970
const Instruction* inst = FindDef(component_type_id);
969971
assert(inst);
970972

971-
if (inst->opcode() == spv::Op::OpTypeFloat ||
972-
inst->opcode() == spv::Op::OpTypeInt)
973-
return inst->word(2);
974-
975-
if (inst->opcode() == spv::Op::OpTypeBool) return 1;
973+
switch (inst->opcode()) {
974+
case spv::Op::OpTypeFloat:
975+
case spv::Op::OpTypeInt:
976+
return inst->word(2);
977+
case spv::Op::OpTypeBool:
978+
return 1;
979+
case spv::Op::OpTypePointer:
980+
case spv::Op::OpTypeUntypedPointerKHR:
981+
assert(inst->GetOperandAs<spv::StorageClass>(1) ==
982+
spv::StorageClass::PhysicalStorageBuffer);
983+
return 64; // all pointers to another PSB is 64-bit
984+
default:
985+
break;
986+
}
976987

977988
assert(0);
978989
return 0;
@@ -1370,6 +1381,27 @@ bool ValidationState_t::GetPointerTypeInfo(
13701381
return true;
13711382
}
13721383

1384+
uint32_t ValidationState_t::GetLargestScalarType(uint32_t id) const {
1385+
const Instruction* inst = FindDef(id);
1386+
uint32_t size = 0;
1387+
1388+
switch (inst->opcode()) {
1389+
case spv::Op::OpTypeStruct:
1390+
for (uint32_t i = 1; i < inst->operands().size(); ++i) {
1391+
const uint32_t member_size =
1392+
GetLargestScalarType(inst->GetOperandAs<uint32_t>(i));
1393+
size = std::max(size, member_size);
1394+
}
1395+
break;
1396+
default:
1397+
const uint32_t bytes = GetBitWidth(id) / 8;
1398+
size = std::max(size, bytes);
1399+
break;
1400+
}
1401+
1402+
return size;
1403+
}
1404+
13731405
bool ValidationState_t::IsAccelerationStructureType(uint32_t id) const {
13741406
const Instruction* inst = FindDef(id);
13751407
return inst && inst->opcode() == spv::Op::OpTypeAccelerationStructureKHR;
@@ -2569,6 +2601,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
25692601
return VUID_WRAP(VUID-StandaloneSpirv-Flat-06202);
25702602
case 6214:
25712603
return VUID_WRAP(VUID-StandaloneSpirv-OpTypeImage-06214);
2604+
case 6314:
2605+
return VUID_WRAP(VUID-StandaloneSpirv-PhysicalStorageBuffer64-06314);
25722606
case 6491:
25732607
return VUID_WRAP(VUID-StandaloneSpirv-DescriptorSet-06491);
25742608
case 6671:

source/val/validation_state.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,12 @@ class ValidationState_t {
731731
/* traverse_all_types = */ false);
732732
}
733733

734+
// Will walk the type to find the largest scalar value size.
735+
// Returns value is in bytes.
736+
// This is designed to pass in the %type from a PSB pointer
737+
// %ptr = OpTypePointer PhysicalStorageBuffer %type
738+
uint32_t GetLargestScalarType(uint32_t id) const;
739+
734740
// Returns true if |id| is a type id that contains |type| (or integer or
735741
// floating point type) of |width| bits.
736742
bool ContainsSizedIntOrFloatType(uint32_t id, spv::Op type,

0 commit comments

Comments
 (0)