@@ -994,6 +994,67 @@ ur_result_t urProgramCreateWithNativeHandle(
994994 return UR_RESULT_SUCCESS;
995995}
996996
997+ // Helper function to validate if a specialization constant ID exists in SPIR-V
998+ static bool isValidSpecConstantId (const uint8_t *spirvCode, size_t spirvSize,
999+ uint32_t specId) {
1000+ if (!spirvCode || spirvSize < 20 ) {
1001+ return false ; // Invalid SPIR-V
1002+ }
1003+
1004+ // Check SPIR-V magic number
1005+ const uint32_t *words = reinterpret_cast <const uint32_t *>(spirvCode);
1006+ if (words[0 ] != 0x07230203 ) {
1007+ return false ; // Invalid SPIR-V magic number
1008+ }
1009+
1010+ // Parse SPIR-V header
1011+ // words[0] = magic number
1012+ // words[1] = version
1013+ // words[2] = generator magic number
1014+ // words[3] = bound on all ids
1015+ // words[4] = schema (0)
1016+ size_t headerSize = 5 ;
1017+ if (spirvSize < headerSize * sizeof (uint32_t )) {
1018+ return false ;
1019+ }
1020+
1021+ // Parse instructions looking for OpSpecConstant* instructions
1022+ size_t pos = headerSize;
1023+ const uint32_t *end = words + (spirvSize / sizeof (uint32_t ));
1024+
1025+ while (pos < (spirvSize / sizeof (uint32_t ))) {
1026+ if (pos >= (end - words))
1027+ break ;
1028+
1029+ uint32_t instruction = words[pos];
1030+ uint16_t opcode = instruction & 0xFFFF ;
1031+ uint16_t length = instruction >> 16 ;
1032+
1033+ if (length == 0 || pos + length > (end - words)) {
1034+ break ; // Invalid instruction
1035+ }
1036+
1037+ // OpSpecConstantTrue = 48, OpSpecConstantFalse = 49, OpSpecConstant = 50
1038+ // OpSpecConstantComposite = 51, OpSpecConstantOp = 52
1039+ if (opcode >= 48 && opcode <= 52 ) {
1040+ if (length >=
1041+ 3 ) { // All OpSpecConstant* instructions have at least 3 words
1042+ // words[pos + 0] = instruction header
1043+ // words[pos + 1] = result type id
1044+ // words[pos + 2] = result id (this is the specialization constant id)
1045+ uint32_t resultId = words[pos + 2 ];
1046+ if (resultId == specId) {
1047+ return true ;
1048+ }
1049+ }
1050+ }
1051+
1052+ pos += length;
1053+ }
1054+
1055+ return false ;
1056+ }
1057+
9971058ur_result_t urProgramSetSpecializationConstants (
9981059 // / [in] handle of the Program object
9991060 ur_program_handle_t Program,
@@ -1004,6 +1065,17 @@ ur_result_t urProgramSetSpecializationConstants(
10041065 const ur_specialization_constant_info_t *SpecConstants) {
10051066 std::scoped_lock<ur_shared_mutex> Guard (Program->Mutex );
10061067
1068+ // Validate each specialization constant ID against the SPIR-V program
1069+ for (uint32_t SpecIt = 0 ; SpecIt < Count; SpecIt++) {
1070+ uint32_t SpecId = SpecConstants[SpecIt].id ;
1071+
1072+ // Validate the spec constant ID exists in the SPIR-V binary
1073+ if (!isValidSpecConstantId (Program->getCode (), Program->getCodeSize (),
1074+ SpecId)) {
1075+ return UR_RESULT_ERROR_INVALID_SPEC_ID;
1076+ }
1077+ }
1078+
10071079 // Remember the value of this specialization constant until the program is
10081080 // built. Note that we only save the pointer to the buffer that contains the
10091081 // value. The caller is responsible for maintaining storage for this buffer.
0 commit comments