Skip to content

Commit a80d3b5

Browse files
authored
Fix untyped pointer comparison validation (KhronosGroup#6004)
* OpPtrEqual and OpPtrNotEqual both allow mixed operands * If both are typed they must be the same type * If either is untyped they must match storage class * OpPtrDiff must match operand types
1 parent 2e55f9c commit a80d3b5

File tree

2 files changed

+250
-5
lines changed

2 files changed

+250
-5
lines changed

source/val/validate_memory.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2781,17 +2781,42 @@ spv_result_t ValidatePtrComparison(ValidationState_t& _,
27812781

27822782
const auto op1 = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
27832783
const auto op2 = _.FindDef(inst->GetOperandAs<uint32_t>(3u));
2784-
if (!op1 || !op2 || op1->type_id() != op2->type_id()) {
2785-
return _.diag(SPV_ERROR_INVALID_ID, inst)
2786-
<< "The types of Operand 1 and Operand 2 must match";
2787-
}
27882784
const auto op1_type = _.FindDef(op1->type_id());
2785+
const auto op2_type = _.FindDef(op2->type_id());
27892786
if (!op1_type || (op1_type->opcode() != spv::Op::OpTypePointer &&
27902787
op1_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
27912788
return _.diag(SPV_ERROR_INVALID_ID, inst)
27922789
<< "Operand type must be a pointer";
27932790
}
27942791

2792+
if (!op2_type || (op2_type->opcode() != spv::Op::OpTypePointer &&
2793+
op2_type->opcode() != spv::Op::OpTypeUntypedPointerKHR)) {
2794+
return _.diag(SPV_ERROR_INVALID_ID, inst)
2795+
<< "Operand type must be a pointer";
2796+
}
2797+
2798+
if (inst->opcode() == spv::Op::OpPtrDiff) {
2799+
if (op1->type_id() != op2->type_id()) {
2800+
return _.diag(SPV_ERROR_INVALID_ID, inst)
2801+
<< "The types of Operand 1 and Operand 2 must match";
2802+
}
2803+
} else {
2804+
const auto either_untyped =
2805+
op1_type->opcode() == spv::Op::OpTypeUntypedPointerKHR ||
2806+
op2_type->opcode() == spv::Op::OpTypeUntypedPointerKHR;
2807+
if (either_untyped) {
2808+
const auto sc1 = op1_type->GetOperandAs<spv::StorageClass>(1);
2809+
const auto sc2 = op2_type->GetOperandAs<spv::StorageClass>(1);
2810+
if (sc1 != sc2) {
2811+
return _.diag(SPV_ERROR_INVALID_ID, inst)
2812+
<< "Pointer storage classes must match";
2813+
}
2814+
} else if (op1->type_id() != op2->type_id()) {
2815+
return _.diag(SPV_ERROR_INVALID_ID, inst)
2816+
<< "The types of Operand 1 and Operand 2 must match";
2817+
}
2818+
}
2819+
27952820
spv::StorageClass sc = op1_type->GetOperandAs<spv::StorageClass>(1u);
27962821
if (_.addressing_model() == spv::AddressingModel::Logical) {
27972822
if (sc != spv::StorageClass::Workgroup &&

test/val/val_memory_test.cpp

Lines changed: 221 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3883,8 +3883,11 @@ OpMemoryModel Logical GLSL450
38833883
%void = OpTypeVoid
38843884
%bool = OpTypeBool
38853885
%int = OpTypeInt 32 0
3886+
%float = OpTypeFloat 32
38863887
%ptr_int = OpTypePointer Private %int
38873888
%var = OpVariable %ptr_int Private
3889+
%ptr_float = OpTypePointer Private %float
3890+
%var2 = OpVariable %ptr_float Private
38883891
%func_ty = OpTypeFunction %void
38893892
%func = OpFunction %void None %func_ty
38903893
%1 = OpLabel
@@ -3897,7 +3900,7 @@ OpMemoryModel Logical GLSL450
38973900
spirv += " %bool ";
38983901
}
38993902

3900-
spirv += R"(%var %ld
3903+
spirv += R"(%var %var2
39013904
OpReturn
39023905
OpFunctionEnd
39033906
)";
@@ -3908,6 +3911,223 @@ OpFunctionEnd
39083911
HasSubstr("The types of Operand 1 and Operand 2 must match"));
39093912
}
39103913

3914+
TEST_P(ValidatePointerComparisons, GoodUntypedPointerSameType) {
3915+
const std::string operation = GetParam();
3916+
3917+
std::string spirv = R"(
3918+
OpCapability Shader
3919+
OpCapability Linkage
3920+
OpCapability VariablePointersStorageBuffer
3921+
OpCapability UntypedPointersKHR
3922+
OpExtension "SPV_KHR_untyped_pointers"
3923+
OpMemoryModel Logical GLSL450
3924+
%void = OpTypeVoid
3925+
%bool = OpTypeBool
3926+
%int = OpTypeInt 32 0
3927+
%ptr = OpTypeUntypedPointerKHR StorageBuffer
3928+
%var = OpUntypedVariableKHR %ptr StorageBuffer
3929+
%func_ty = OpTypeFunction %void
3930+
%func = OpFunction %void None %func_ty
3931+
%1 = OpLabel
3932+
%equal = )" + operation;
3933+
3934+
if (operation == "OpPtrDiff") {
3935+
spirv += " %int ";
3936+
} else {
3937+
spirv += " %bool ";
3938+
}
3939+
3940+
spirv += R"(%var %var
3941+
OpReturn
3942+
OpFunctionEnd
3943+
)";
3944+
3945+
CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_4);
3946+
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_4));
3947+
}
3948+
3949+
TEST_P(ValidatePointerComparisons, GoodUntypedPointerSameStorageClass) {
3950+
const std::string operation = GetParam();
3951+
3952+
std::string spirv = R"(
3953+
OpCapability Shader
3954+
OpCapability Linkage
3955+
OpCapability VariablePointersStorageBuffer
3956+
OpCapability UntypedPointersKHR
3957+
OpExtension "SPV_KHR_untyped_pointers"
3958+
OpMemoryModel Logical GLSL450
3959+
%void = OpTypeVoid
3960+
%bool = OpTypeBool
3961+
%int = OpTypeInt 32 0
3962+
%ptr1 = OpTypeUntypedPointerKHR StorageBuffer
3963+
%var = OpUntypedVariableKHR %ptr1 StorageBuffer
3964+
%ptr2 = OpTypeUntypedPointerKHR StorageBuffer
3965+
%var2 = OpUntypedVariableKHR %ptr2 StorageBuffer
3966+
%func_ty = OpTypeFunction %void
3967+
%func = OpFunction %void None %func_ty
3968+
%1 = OpLabel
3969+
%equal = )" + operation;
3970+
3971+
if (operation == "OpPtrDiff") {
3972+
spirv += " %int ";
3973+
} else {
3974+
spirv += " %bool ";
3975+
}
3976+
3977+
spirv += R"(%var %var2
3978+
OpReturn
3979+
OpFunctionEnd
3980+
)";
3981+
3982+
CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_4);
3983+
if (operation == "OpPtrDiff") {
3984+
EXPECT_EQ(SPV_ERROR_INVALID_ID,
3985+
ValidateInstructions(SPV_ENV_UNIVERSAL_1_4));
3986+
EXPECT_THAT(getDiagnosticString(),
3987+
HasSubstr("The types of Operand 1 and Operand 2 must match"));
3988+
} else {
3989+
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_4));
3990+
}
3991+
}
3992+
3993+
TEST_P(ValidatePointerComparisons, BadUntypedPointerDiffStorageClass) {
3994+
const std::string operation = GetParam();
3995+
3996+
std::string spirv = R"(
3997+
OpCapability Shader
3998+
OpCapability Linkage
3999+
OpCapability VariablePointers
4000+
OpCapability UntypedPointersKHR
4001+
OpExtension "SPV_KHR_untyped_pointers"
4002+
OpMemoryModel Logical GLSL450
4003+
%void = OpTypeVoid
4004+
%bool = OpTypeBool
4005+
%int = OpTypeInt 32 0
4006+
%ptr1 = OpTypeUntypedPointerKHR StorageBuffer
4007+
%var1 = OpUntypedVariableKHR %ptr1 StorageBuffer
4008+
%ptr2 = OpTypeUntypedPointerKHR Workgroup
4009+
%var2 = OpUntypedVariableKHR %ptr2 Workgroup %int
4010+
%func_ty = OpTypeFunction %void
4011+
%func = OpFunction %void None %func_ty
4012+
%1 = OpLabel
4013+
%equal = )" + operation;
4014+
4015+
if (operation == "OpPtrDiff") {
4016+
spirv += " %int ";
4017+
} else {
4018+
spirv += " %bool ";
4019+
}
4020+
4021+
spirv += R"(%var1 %var2
4022+
OpReturn
4023+
OpFunctionEnd
4024+
)";
4025+
4026+
CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_4);
4027+
if (operation == "OpPtrDiff") {
4028+
EXPECT_EQ(SPV_ERROR_INVALID_ID,
4029+
ValidateInstructions(SPV_ENV_UNIVERSAL_1_4));
4030+
EXPECT_THAT(getDiagnosticString(),
4031+
HasSubstr("The types of Operand 1 and Operand 2 must match"));
4032+
} else {
4033+
EXPECT_EQ(SPV_ERROR_INVALID_ID,
4034+
ValidateInstructions(SPV_ENV_UNIVERSAL_1_4));
4035+
EXPECT_THAT(getDiagnosticString(),
4036+
HasSubstr("Pointer storage classes must match"));
4037+
}
4038+
}
4039+
4040+
TEST_P(ValidatePointerComparisons, GoodMixedPointerSameStorageClass) {
4041+
const std::string operation = GetParam();
4042+
4043+
std::string spirv = R"(
4044+
OpCapability Shader
4045+
OpCapability Linkage
4046+
OpCapability VariablePointersStorageBuffer
4047+
OpCapability UntypedPointersKHR
4048+
OpExtension "SPV_KHR_untyped_pointers"
4049+
OpMemoryModel Logical GLSL450
4050+
%void = OpTypeVoid
4051+
%bool = OpTypeBool
4052+
%int = OpTypeInt 32 0
4053+
%ptr1 = OpTypeUntypedPointerKHR StorageBuffer
4054+
%var = OpUntypedVariableKHR %ptr1 StorageBuffer
4055+
%ptr2 = OpTypePointer StorageBuffer %int
4056+
%var2 = OpVariable %ptr2 StorageBuffer
4057+
%func_ty = OpTypeFunction %void
4058+
%func = OpFunction %void None %func_ty
4059+
%1 = OpLabel
4060+
%equal = )" + operation;
4061+
4062+
if (operation == "OpPtrDiff") {
4063+
spirv += " %int ";
4064+
} else {
4065+
spirv += " %bool ";
4066+
}
4067+
4068+
spirv += R"(%var %var2
4069+
OpReturn
4070+
OpFunctionEnd
4071+
)";
4072+
4073+
CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_4);
4074+
if (operation == "OpPtrDiff") {
4075+
EXPECT_EQ(SPV_ERROR_INVALID_ID,
4076+
ValidateInstructions(SPV_ENV_UNIVERSAL_1_4));
4077+
EXPECT_THAT(getDiagnosticString(),
4078+
HasSubstr("The types of Operand 1 and Operand 2 must match"));
4079+
} else {
4080+
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_4));
4081+
}
4082+
}
4083+
4084+
TEST_P(ValidatePointerComparisons, BadMixedPointerDiffStorageClass) {
4085+
const std::string operation = GetParam();
4086+
4087+
std::string spirv = R"(
4088+
OpCapability Shader
4089+
OpCapability Linkage
4090+
OpCapability VariablePointers
4091+
OpCapability UntypedPointersKHR
4092+
OpExtension "SPV_KHR_untyped_pointers"
4093+
OpMemoryModel Logical GLSL450
4094+
%void = OpTypeVoid
4095+
%bool = OpTypeBool
4096+
%int = OpTypeInt 32 0
4097+
%ptr1 = OpTypeUntypedPointerKHR StorageBuffer
4098+
%var1 = OpUntypedVariableKHR %ptr1 StorageBuffer
4099+
%ptr2 = OpTypePointer Workgroup %int
4100+
%var2 = OpVariable %ptr2 Workgroup
4101+
%func_ty = OpTypeFunction %void
4102+
%func = OpFunction %void None %func_ty
4103+
%1 = OpLabel
4104+
%equal = )" + operation;
4105+
4106+
if (operation == "OpPtrDiff") {
4107+
spirv += " %int ";
4108+
} else {
4109+
spirv += " %bool ";
4110+
}
4111+
4112+
spirv += R"(%var1 %var2
4113+
OpReturn
4114+
OpFunctionEnd
4115+
)";
4116+
4117+
CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_4);
4118+
if (operation == "OpPtrDiff") {
4119+
EXPECT_EQ(SPV_ERROR_INVALID_ID,
4120+
ValidateInstructions(SPV_ENV_UNIVERSAL_1_4));
4121+
EXPECT_THAT(getDiagnosticString(),
4122+
HasSubstr("The types of Operand 1 and Operand 2 must match"));
4123+
} else {
4124+
EXPECT_EQ(SPV_ERROR_INVALID_ID,
4125+
ValidateInstructions(SPV_ENV_UNIVERSAL_1_4));
4126+
EXPECT_THAT(getDiagnosticString(),
4127+
HasSubstr("Pointer storage classes must match"));
4128+
}
4129+
}
4130+
39114131
INSTANTIATE_TEST_SUITE_P(PointerComparisons, ValidatePointerComparisons,
39124132
Values("OpPtrEqual", "OpPtrNotEqual", "OpPtrDiff"));
39134133

0 commit comments

Comments
 (0)