Skip to content

Commit c0455d4

Browse files
authored
Extra restrictions for accesses of block arrays (KhronosGroup#6226)
* Extra restrictions for accesses of block arrays * If a PtrAccessChain is rooted on a block, element (if constant) must be zero * UntypedAccessChains check that block arrays must not be reinterpreted * Basic element operand checks for ptr access chains * add tests * formatting * formatting
1 parent 73d28b5 commit c0455d4

File tree

4 files changed

+394
-9
lines changed

4 files changed

+394
-9
lines changed

source/val/validate_memory.cpp

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,6 +1555,60 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
15551555
return _.diag(SPV_ERROR_INVALID_ID, inst)
15561556
<< "Base type must be a non-pointer type";
15571557
}
1558+
1559+
const auto ContainsBlock = [&_](const Instruction* type_inst) {
1560+
if (type_inst->opcode() == spv::Op::OpTypeStruct) {
1561+
if (_.HasDecoration(type_inst->id(), spv::Decoration::Block) ||
1562+
_.HasDecoration(type_inst->id(), spv::Decoration::BufferBlock)) {
1563+
return true;
1564+
}
1565+
}
1566+
return false;
1567+
};
1568+
1569+
// Block (and BufferBlock) arrays cannot be reinterpreted via untyped access
1570+
// chains.
1571+
const bool base_type_block_array =
1572+
base_type->opcode() == spv::Op::OpTypeArray &&
1573+
_.ContainsType(base_type->id(), ContainsBlock,
1574+
/* traverse_all_types = */ false);
1575+
1576+
const auto base_index = untyped_pointer ? 3 : 2;
1577+
const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
1578+
auto base = _.FindDef(base_id);
1579+
// Strictly speaking this misses trivial access chains and function
1580+
// parameter chasing, but that would be a significant complication in the
1581+
// traversal.
1582+
while (base->opcode() == spv::Op::OpCopyObject) {
1583+
base = _.FindDef(base->GetOperandAs<uint32_t>(2));
1584+
}
1585+
const Instruction* base_data_type = nullptr;
1586+
if (base->opcode() == spv::Op::OpVariable) {
1587+
const auto ptr_type = _.FindDef(base->type_id());
1588+
base_data_type = _.FindDef(ptr_type->GetOperandAs<uint32_t>(2));
1589+
} else if (base->opcode() == spv::Op::OpUntypedVariableKHR) {
1590+
if (base->operands().size() > 3) {
1591+
base_data_type = _.FindDef(base->GetOperandAs<uint32_t>(3));
1592+
}
1593+
}
1594+
1595+
if (base_data_type) {
1596+
const bool base_block_array =
1597+
base_data_type->opcode() == spv::Op::OpTypeArray &&
1598+
_.ContainsType(base_data_type->id(), ContainsBlock,
1599+
/* traverse_all_types = */ false);
1600+
1601+
if (base_type_block_array != base_block_array) {
1602+
return _.diag(SPV_ERROR_INVALID_ID, inst)
1603+
<< "Both Base Type and Base must be Block or BufferBlock arrays "
1604+
"or neither can be";
1605+
} else if (base_type_block_array && base_block_array &&
1606+
base_type->id() != base_data_type->id()) {
1607+
return _.diag(SPV_ERROR_INVALID_ID, inst)
1608+
<< "If Base or Base Type is a Block or BufferBlock array, the "
1609+
"other must also be the same array";
1610+
}
1611+
}
15581612
}
15591613

15601614
// Base must be a pointer, pointing to the base of a composite object.
@@ -1845,14 +1899,34 @@ spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
18451899

18461900
const bool untyped_pointer = spvOpcodeGeneratesUntypedPointer(inst->opcode());
18471901

1848-
const auto base_id = inst->GetOperandAs<uint32_t>(2);
1849-
const auto base = _.FindDef(base_id);
1850-
const auto base_type = untyped_pointer
1851-
? _.FindDef(inst->GetOperandAs<uint32_t>(2))
1852-
: _.FindDef(base->type_id());
1902+
const auto base_idx = untyped_pointer ? 3 : 2;
1903+
const auto base = _.FindDef(inst->GetOperandAs<uint32_t>(base_idx));
1904+
const auto base_type = _.FindDef(base->type_id());
18531905
const auto base_type_storage_class =
18541906
base_type->GetOperandAs<spv::StorageClass>(1);
18551907

1908+
const auto element_idx = untyped_pointer ? 4 : 3;
1909+
const auto element = _.FindDef(inst->GetOperandAs<uint32_t>(element_idx));
1910+
const auto element_type = _.FindDef(element->type_id());
1911+
if (!element_type || element_type->opcode() != spv::Op::OpTypeInt) {
1912+
return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Element must be an integer";
1913+
}
1914+
uint64_t element_val = 0;
1915+
if (_.EvalConstantValUint64(element->id(), &element_val)) {
1916+
if (element_val != 0) {
1917+
const auto interp_type =
1918+
untyped_pointer ? _.FindDef(inst->GetOperandAs<uint32_t>(2))
1919+
: _.FindDef(base_type->GetOperandAs<uint32_t>(2));
1920+
if (interp_type->opcode() == spv::Op::OpTypeStruct &&
1921+
(_.HasDecoration(interp_type->id(), spv::Decoration::Block) ||
1922+
_.HasDecoration(interp_type->id(), spv::Decoration::BufferBlock))) {
1923+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
1924+
<< "Element must be 0 if the interpretation type is a Block- or "
1925+
"BufferBlock-decorated structure";
1926+
}
1927+
}
1928+
}
1929+
18561930
if (_.HasCapability(spv::Capability::Shader) &&
18571931
(base_type_storage_class == spv::StorageClass::Uniform ||
18581932
base_type_storage_class == spv::StorageClass::StorageBuffer ||

test/opt/eliminate_dead_member_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,7 @@ TEST_F(EliminateDeadMemberTest, RemoveMemberPtrAccessChain) {
958958
; CHECK: OpMemberDecorate %type__Globals 1 Offset 16
959959
; CHECK: %type__Globals = OpTypeStruct %float %float
960960
; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Uniform_type__Globals %_Globals %uint_0
961-
; CHECK: OpPtrAccessChain %_ptr_Uniform_float [[ac]] %uint_1 %uint_0
961+
; CHECK: OpPtrAccessChain %_ptr_Uniform_float [[ac]] %uint_0 %uint_0
962962
; CHECK: OpPtrAccessChain %_ptr_Uniform_float [[ac]] %uint_0 %uint_1
963963
OpCapability Shader
964964
OpCapability VariablePointersStorageBuffer
@@ -995,7 +995,7 @@ TEST_F(EliminateDeadMemberTest, RemoveMemberPtrAccessChain) {
995995
%main = OpFunction %void None %14
996996
%16 = OpLabel
997997
%17 = OpAccessChain %_ptr_Uniform_type__Globals %_Globals %uint_0
998-
%18 = OpPtrAccessChain %_ptr_Uniform_float %17 %uint_1 %uint_0
998+
%18 = OpPtrAccessChain %_ptr_Uniform_float %17 %uint_0 %uint_0
999999
%19 = OpPtrAccessChain %_ptr_Uniform_float %17 %uint_0 %uint_2
10001000
OpReturn
10011001
OpFunctionEnd

test/val/val_decoration_test.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10304,6 +10304,7 @@ OpMemberDecorate %struct 0 Offset 0
1030410304
OpMemberDecorate %struct 1 Offset 4
1030510305
)" + set + R"(OpMemberDecorate %test_type 0 Offset 0
1030610306
OpMemberDecorate %test_type 1 Offset 1
10307+
OpDecorate %ptr ArrayStride 16
1030710308
%void = OpTypeVoid
1030810309
%int = OpTypeInt 32 0
1030910310
%int_0 = OpConstant %int 0
@@ -10312,7 +10313,8 @@ OpMemberDecorate %test_type 1 Offset 1
1031210313
%test_val = OpConstantNull %test_type
1031310314
%ptr = OpTypeUntypedPointerKHR )" +
1031410315
sc + R"(
10315-
%var = OpUntypedVariableKHR %ptr )" + sc + R"( %struct
10316+
%var = OpUntypedVariableKHR %ptr )" +
10317+
sc + R"( %struct
1031610318
%void_fn = OpTypeFunction %void
1031710319
%main = OpFunction %void None %void_fn
1031810320
%entry = OpLabel
@@ -10355,6 +10357,7 @@ OpDecorate %struct Block
1035510357
OpMemberDecorate %struct 0 Offset 0
1035610358
OpMemberDecorate %struct 1 Offset 4
1035710359
)" + set + R"(OpDecorate %test_type ArrayStride 4
10360+
OpDecorate %ptr ArrayStride 16
1035810361
%void = OpTypeVoid
1035910362
%int = OpTypeInt 32 0
1036010363
%int_0 = OpConstant %int 0
@@ -10365,7 +10368,8 @@ OpMemberDecorate %struct 1 Offset 4
1036510368
%struct = OpTypeStruct %int %int
1036610369
%ptr = OpTypeUntypedPointerKHR )" +
1036710370
sc + R"(
10368-
%var = OpUntypedVariableKHR %ptr )" + sc + R"( %struct
10371+
%var = OpUntypedVariableKHR %ptr )" +
10372+
sc + R"( %struct
1036910373
%void_fn = OpTypeFunction %void
1037010374
%main = OpFunction %void None %void_fn
1037110375
%entry = OpLabel

0 commit comments

Comments
 (0)