Skip to content

Commit bc330e0

Browse files
[UR][L0] urProgramSetSpecializationConstants to returns error
Now urProgramSpcializationConstants will return UR_RESULT_ERROR_INVALID_SPEC_ID when the incorrect id is used. Signed-off-by: Zhang, Winston <[email protected]>
1 parent 5c0a51d commit bc330e0

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

unified-runtime/source/adapters/level_zero/program.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,67 @@ ur_result_t urProgramCreateWithNativeHandle(
994994
return UR_RESULT_SUCCESS;
995995
}
996996

997+
// Helper function to validate if a specialization constant ID exists in SPIR-V
998+
static bool isValidSpecConstantId(const uint8_t *spirvCode, size_t spirvSize,
999+
uint32_t specId) {
1000+
if (!spirvCode || spirvSize < 20) {
1001+
return false; // Invalid SPIR-V
1002+
}
1003+
1004+
// Check SPIR-V magic number
1005+
const uint32_t *words = reinterpret_cast<const uint32_t *>(spirvCode);
1006+
if (words[0] != 0x07230203) {
1007+
return false; // Invalid SPIR-V magic number
1008+
}
1009+
1010+
// Parse SPIR-V header
1011+
// words[0] = magic number
1012+
// words[1] = version
1013+
// words[2] = generator magic number
1014+
// words[3] = bound on all ids
1015+
// words[4] = schema (0)
1016+
size_t headerSize = 5;
1017+
if (spirvSize < headerSize * sizeof(uint32_t)) {
1018+
return false;
1019+
}
1020+
1021+
// Parse instructions looking for OpSpecConstant* instructions
1022+
size_t pos = headerSize;
1023+
const uint32_t *end = words + (spirvSize / sizeof(uint32_t));
1024+
1025+
while (pos < (spirvSize / sizeof(uint32_t))) {
1026+
if (pos >= (end - words))
1027+
break;
1028+
1029+
uint32_t instruction = words[pos];
1030+
uint16_t opcode = instruction & 0xFFFF;
1031+
uint16_t length = instruction >> 16;
1032+
1033+
if (length == 0 || pos + length > (end - words)) {
1034+
break; // Invalid instruction
1035+
}
1036+
1037+
// OpSpecConstantTrue = 48, OpSpecConstantFalse = 49, OpSpecConstant = 50
1038+
// OpSpecConstantComposite = 51, OpSpecConstantOp = 52
1039+
if (opcode >= 48 && opcode <= 52) {
1040+
if (length >=
1041+
3) { // All OpSpecConstant* instructions have at least 3 words
1042+
// words[pos + 0] = instruction header
1043+
// words[pos + 1] = result type id
1044+
// words[pos + 2] = result id (this is the specialization constant id)
1045+
uint32_t resultId = words[pos + 2];
1046+
if (resultId == specId) {
1047+
return true;
1048+
}
1049+
}
1050+
}
1051+
1052+
pos += length;
1053+
}
1054+
1055+
return false;
1056+
}
1057+
9971058
ur_result_t urProgramSetSpecializationConstants(
9981059
/// [in] handle of the Program object
9991060
ur_program_handle_t Program,
@@ -1004,6 +1065,17 @@ ur_result_t urProgramSetSpecializationConstants(
10041065
const ur_specialization_constant_info_t *SpecConstants) {
10051066
std::scoped_lock<ur_shared_mutex> Guard(Program->Mutex);
10061067

1068+
// Validate each specialization constant ID against the SPIR-V program
1069+
for (uint32_t SpecIt = 0; SpecIt < Count; SpecIt++) {
1070+
uint32_t SpecId = SpecConstants[SpecIt].id;
1071+
1072+
// Validate the spec constant ID exists in the SPIR-V binary
1073+
if (!isValidSpecConstantId(Program->getCode(), Program->getCodeSize(),
1074+
SpecId)) {
1075+
return UR_RESULT_ERROR_INVALID_SPEC_ID;
1076+
}
1077+
}
1078+
10071079
// Remember the value of this specialization constant until the program is
10081080
// built. Note that we only save the pointer to the buffer that contains the
10091081
// value. The caller is responsible for maintaining storage for this buffer.

0 commit comments

Comments
 (0)