@@ -196,19 +196,18 @@ bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
196196 return false ;
197197}
198198
199- std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass (
200- ValidationState_t& _, const Instruction* inst) {
201- spv::StorageClass dst_sc = spv::StorageClass::Max ;
202- spv::StorageClass src_sc = spv::StorageClass::Max ;
199+ std::pair<Instruction*, Instruction*> GetPointerTypes (ValidationState_t& _,
200+ const Instruction* inst) {
201+ Instruction* dst_pointer_type = nullptr ;
202+ Instruction* src_pointer_type = nullptr ;
203203 switch (inst->opcode ()) {
204204 case spv::Op::OpCooperativeMatrixLoadNV:
205205 case spv::Op::OpCooperativeMatrixLoadTensorNV:
206206 case spv::Op::OpCooperativeMatrixLoadKHR:
207207 case spv::Op::OpCooperativeVectorLoadNV:
208208 case spv::Op::OpLoad: {
209209 auto load_pointer = _.FindDef (inst->GetOperandAs <uint32_t >(2 ));
210- auto load_pointer_type = _.FindDef (load_pointer->type_id ());
211- dst_sc = load_pointer_type->GetOperandAs <spv::StorageClass>(1 );
210+ dst_pointer_type = _.FindDef (load_pointer->type_id ());
212211 break ;
213212 }
214213 case spv::Op::OpCooperativeMatrixStoreNV:
@@ -217,25 +216,23 @@ std::pair<spv::StorageClass, spv::StorageClass> GetStorageClass(
217216 case spv::Op::OpCooperativeVectorStoreNV:
218217 case spv::Op::OpStore: {
219218 auto store_pointer = _.FindDef (inst->GetOperandAs <uint32_t >(0 ));
220- auto store_pointer_type = _.FindDef (store_pointer->type_id ());
221- dst_sc = store_pointer_type->GetOperandAs <spv::StorageClass>(1 );
219+ dst_pointer_type = _.FindDef (store_pointer->type_id ());
222220 break ;
223221 }
222+ // Spec: "Matching Storage Class is not required"
224223 case spv::Op::OpCopyMemory:
225224 case spv::Op::OpCopyMemorySized: {
226- auto dst = _.FindDef (inst->GetOperandAs <uint32_t >(0 ));
227- auto dst_type = _.FindDef (dst->type_id ());
228- dst_sc = dst_type->GetOperandAs <spv::StorageClass>(1 );
229- auto src = _.FindDef (inst->GetOperandAs <uint32_t >(1 ));
230- auto src_type = _.FindDef (src->type_id ());
231- src_sc = src_type->GetOperandAs <spv::StorageClass>(1 );
225+ auto dst_pointer = _.FindDef (inst->GetOperandAs <uint32_t >(0 ));
226+ dst_pointer_type = _.FindDef (dst_pointer->type_id ());
227+ auto src_pointer = _.FindDef (inst->GetOperandAs <uint32_t >(1 ));
228+ src_pointer_type = _.FindDef (src_pointer->type_id ());
232229 break ;
233230 }
234231 default :
235232 break ;
236233 }
237234
238- return std::make_pair (dst_sc, src_sc );
235+ return std::make_pair (dst_pointer_type, src_pointer_type );
239236}
240237
241238// Returns the number of instruction words taken up by a memory access
@@ -288,8 +285,17 @@ bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) {
288285
289286spv_result_t CheckMemoryAccess (ValidationState_t& _, const Instruction* inst,
290287 uint32_t index) {
291- spv::StorageClass dst_sc, src_sc;
292- std::tie (dst_sc, src_sc) = GetStorageClass (_, inst);
288+ Instruction* dst_pointer_type = nullptr ;
289+ Instruction* src_pointer_type = nullptr ; // only used for OpCopyMemory
290+ std::tie (dst_pointer_type, src_pointer_type) = GetPointerTypes (_, inst);
291+
292+ const spv::StorageClass dst_sc =
293+ dst_pointer_type ? dst_pointer_type->GetOperandAs <spv::StorageClass>(1 )
294+ : spv::StorageClass::Max;
295+ const spv::StorageClass src_sc =
296+ src_pointer_type ? src_pointer_type->GetOperandAs <spv::StorageClass>(1 )
297+ : spv::StorageClass::Max;
298+
293299 if (inst->operands ().size () <= index) {
294300 // Cases where lack of some operand is invalid
295301 if (src_sc == spv::StorageClass::PhysicalStorageBuffer ||
@@ -390,6 +396,23 @@ spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
390396 << " Memory accesses Aligned operand value " << aligned_value
391397 << " is not a power of two." ;
392398 }
399+
400+ uint32_t largest_scalar = 0 ;
401+ if (dst_sc == spv::StorageClass::PhysicalStorageBuffer) {
402+ largest_scalar =
403+ _.GetLargestScalarType (dst_pointer_type->GetOperandAs <uint32_t >(2 ));
404+ }
405+ if (src_sc == spv::StorageClass::PhysicalStorageBuffer) {
406+ largest_scalar = std::max (
407+ largest_scalar,
408+ _.GetLargestScalarType (src_pointer_type->GetOperandAs <uint32_t >(2 )));
409+ }
410+ if (aligned_value < largest_scalar) {
411+ return _.diag (SPV_ERROR_INVALID_ID, inst)
412+ << _.VkErrorID (6314 ) << " Memory accesses Aligned operand value "
413+ << aligned_value << " is too small, the largest scalar type is "
414+ << largest_scalar << " bytes." ;
415+ }
393416 }
394417
395418 return SPV_SUCCESS;
0 commit comments