Skip to content

Commit 7c77897

Browse files
author
LU-JOHN
authored
Allow fmt arg to printf to be an array of i8 in non-constant space (KhronosGroup#5677)
* In spirv-val allow format arg to printf to be an array of i8 in Generic space Signed-off-by: Lu, John <[email protected]> * Allow more addr spaces for printf format string Signed-off-by: Lu, John <[email protected]> * Update printf format arg testcase Signed-off-by: Lu, John <[email protected]> * Apply clang-format Signed-off-by: Lu, John <[email protected]> * Reorder code for clarity Signed-off-by: Lu, John <[email protected]> * Only allow other addr spaces if extension is seen Signed-off-by: Lu, John <[email protected]> * Add test to check printf format with extension Signed-off-by: Lu, John <[email protected]> * Add extension correctly Signed-off-by: Lu, John <[email protected]> --------- Signed-off-by: Lu, John <[email protected]>
1 parent 257cacf commit 7c77897

File tree

5 files changed

+69
-4
lines changed

5 files changed

+69
-4
lines changed

source/val/validate_extensions.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2962,12 +2962,38 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) {
29622962
<< "expected operand Format to be a pointer";
29632963
}
29642964

2965-
if (format_storage_class != spv::StorageClass::UniformConstant) {
2966-
return _.diag(SPV_ERROR_INVALID_DATA, inst)
2967-
<< ext_inst_name() << ": "
2968-
<< "expected Format storage class to be UniformConstant";
2965+
if (_.HasExtension(
2966+
Extension::kSPV_EXT_relaxed_printf_string_address_space)) {
2967+
if (format_storage_class != spv::StorageClass::UniformConstant &&
2968+
// Extension SPV_EXT_relaxed_printf_string_address_space allows
2969+
// format strings in Global, Local, Private and Generic address
2970+
// spaces
2971+
2972+
// Global
2973+
format_storage_class != spv::StorageClass::CrossWorkgroup &&
2974+
// Local
2975+
format_storage_class != spv::StorageClass::Workgroup &&
2976+
// Private
2977+
format_storage_class != spv::StorageClass::Function &&
2978+
// Generic
2979+
format_storage_class != spv::StorageClass::Generic) {
2980+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
2981+
<< ext_inst_name() << ": "
2982+
<< "expected Format storage class to be UniformConstant, "
2983+
"Crossworkgroup, Workgroup, Function, or Generic";
2984+
}
2985+
} else {
2986+
if (format_storage_class != spv::StorageClass::UniformConstant) {
2987+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
2988+
<< ext_inst_name() << ": "
2989+
<< "expected Format storage class to be UniformConstant";
2990+
}
29692991
}
29702992

2993+
// If pointer points to an array, get the type of an element
2994+
if (_.IsIntArrayType(format_data_type))
2995+
format_data_type = _.GetComponentType(format_data_type);
2996+
29712997
if (!_.IsIntScalarType(format_data_type) ||
29722998
_.GetBitWidth(format_data_type) != 8) {
29732999
return _.diag(SPV_ERROR_INVALID_DATA, inst)

source/val/validation_state.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,9 @@ uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
869869
case spv::Op::OpTypeBool:
870870
return id;
871871

872+
case spv::Op::OpTypeArray:
873+
return inst->word(2);
874+
872875
case spv::Op::OpTypeVector:
873876
return inst->word(2);
874877

@@ -992,6 +995,19 @@ bool ValidationState_t::IsIntScalarType(uint32_t id) const {
992995
return inst && inst->opcode() == spv::Op::OpTypeInt;
993996
}
994997

998+
bool ValidationState_t::IsIntArrayType(uint32_t id) const {
999+
const Instruction* inst = FindDef(id);
1000+
if (!inst) {
1001+
return false;
1002+
}
1003+
1004+
if (inst->opcode() == spv::Op::OpTypeArray) {
1005+
return IsIntScalarType(GetComponentType(id));
1006+
}
1007+
1008+
return false;
1009+
}
1010+
9951011
bool ValidationState_t::IsIntVectorType(uint32_t id) const {
9961012
const Instruction* inst = FindDef(id);
9971013
if (!inst) {

source/val/validation_state.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ class ValidationState_t {
606606
bool IsFloatScalarOrVectorType(uint32_t id) const;
607607
bool IsFloatMatrixType(uint32_t id) const;
608608
bool IsIntScalarType(uint32_t id) const;
609+
bool IsIntArrayType(uint32_t id) const;
609610
bool IsIntVectorType(uint32_t id) const;
610611
bool IsIntScalarOrVectorType(uint32_t id) const;
611612
bool IsUnsignedIntScalarType(uint32_t id) const;

test/val/val_ext_inst_test.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ OpCapability Matrix
447447
%u8arr_uniform_constant = OpVariable %u8arr_ptr_uniform_constant UniformConstant
448448
%u8_ptr_uniform_constant = OpTypePointer UniformConstant %u8
449449
%u8_ptr_generic = OpTypePointer Generic %u8
450+
%u8_ptr_input = OpTypePointer Input %u8
450451
451452
%main = OpFunction %void None %func
452453
%main_entry = OpLabel
@@ -5269,6 +5270,26 @@ TEST_F(ValidateExtInst, OpenCLStdPrintfFormatNotUniformConstStorageClass) {
52695270
"be UniformConstant"));
52705271
}
52715272

5273+
TEST_F(ValidateExtInst,
5274+
OpenCLStdPrintfFormatWithExtensionNotAllowedStorageClass) {
5275+
const std::string body = R"(
5276+
%format_const = OpAccessChain %u8_ptr_uniform_constant %u8arr_uniform_constant %u32_0
5277+
%format = OpBitcast %u8_ptr_input %format_const
5278+
%val1 = OpExtInst %u32 %extinst printf %format %u32_0 %u32_1
5279+
)";
5280+
5281+
const std::string extension = R"(
5282+
OpExtension "SPV_EXT_relaxed_printf_string_address_space"
5283+
)";
5284+
5285+
CompileSuccessfully(GenerateKernelCode(body, extension));
5286+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
5287+
EXPECT_THAT(getDiagnosticString(),
5288+
HasSubstr("OpenCL.std printf: expected Format storage class to "
5289+
"be UniformConstant, Crossworkgroup, Workgroup, "
5290+
"Function, or Generic"));
5291+
}
5292+
52725293
TEST_F(ValidateExtInst, OpenCLStdPrintfFormatNotU8Pointer) {
52735294
const std::string body = R"(
52745295
%format = OpAccessChain %u32_ptr_uniform_constant %u32vec8_uniform_constant %u32_0

utils/generate_grammar_tables.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
SPV_AMD_gpu_shader_int16
3232
SPV_AMD_shader_trinary_minmax
3333
SPV_KHR_non_semantic_info
34+
SPV_EXT_relaxed_printf_string_address_space
3435
"""
3536

3637
OUTPUT_LANGUAGE = 'c'

0 commit comments

Comments
 (0)