@@ -55,6 +55,26 @@ checkUnresolvedSymbols(ze_module_handle_t ZeModule,
5555}
5656} // extern "C"
5757
58+ static ur_program_handle_t_::CodeFormat matchILCodeFormat (const void *Input,
59+ size_t Length) {
60+ const auto MatchMagicNumber = [&](uint32_t Number) {
61+ return Length >= sizeof (Number) &&
62+ std::memcmp (Input, &Number, sizeof (Number)) == 0 ;
63+ };
64+
65+ // SPIR-V Specification: 3.1 Magic Number
66+ // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Magic
67+ if (MatchMagicNumber (0x07230203 )) {
68+ return ur_program_handle_t_::CodeFormat::SPIRV;
69+ }
70+
71+ return ur_program_handle_t_::CodeFormat::Unknown;
72+ }
73+
74+ static bool isCodeFormatIL (ur_program_handle_t_::CodeFormat CodeFormat) {
75+ return CodeFormat == ur_program_handle_t_::CodeFormat::SPIRV;
76+ }
77+
5878namespace ur ::level_zero {
5979
6080ur_result_t urProgramCreateWithIL (
@@ -70,9 +90,12 @@ ur_result_t urProgramCreateWithIL(
7090 ur_program_handle_t *Program) {
7191 UR_ASSERT (Context, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
7292 UR_ASSERT (IL && Program, UR_RESULT_ERROR_INVALID_NULL_POINTER);
93+ const ur_program_handle_t_::CodeFormat CodeFormat =
94+ matchILCodeFormat (IL, Length);
95+ UR_ASSERT (isCodeFormatIL (CodeFormat), UR_RESULT_ERROR_INVALID_BINARY);
7396 try {
74- ur_program_handle_t_ *UrProgram =
75- new ur_program_handle_t_ (ur_program_handle_t_ ::IL, Context, IL, Length);
97+ ur_program_handle_t_ *UrProgram = new ur_program_handle_t_ (
98+ ur_program_handle_t_::IL, Context, IL, Length, CodeFormat );
7699 *Program = reinterpret_cast <ur_program_handle_t >(UrProgram);
77100 } catch (const std::bad_alloc &) {
78101 return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
@@ -195,9 +218,17 @@ ur_result_t urProgramBuildExp(
195218 auto Code = hProgram->getCode (ZeDevice);
196219 UR_ASSERT (Code, UR_RESULT_ERROR_INVALID_PROGRAM);
197220
198- ZeModuleDesc.format = (State == ur_program_handle_t_::IL)
199- ? ZE_MODULE_FORMAT_IL_SPIRV
200- : ZE_MODULE_FORMAT_NATIVE;
221+ switch (hProgram->getCodeFormat (ZeDevice)) {
222+ case ur_program_handle_t_::CodeFormat::SPIRV:
223+ ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
224+ break ;
225+ case ur_program_handle_t_::CodeFormat::Native:
226+ ZeModuleDesc.format = ZE_MODULE_FORMAT_NATIVE;
227+ break ;
228+ default :
229+ ur::unreachable ();
230+ return UR_RESULT_ERROR_INVALID_PROGRAM;
231+ }
201232 ZeModuleDesc.inputSize = hProgram->getCodeSize (ZeDevice);
202233 ZeModuleDesc.pInputModule = Code;
203234 ze_context_handle_t ZeContext = hProgram->Context ->getZeHandle ();
@@ -364,6 +395,8 @@ ur_result_t urProgramLinkExp(
364395 // locks simultaneously with "exclusive" access. However, there is no such
365396 // code like that, so this is also not a danger.
366397 std::vector<std::shared_lock<ur_shared_mutex>> Guards (count);
398+ const ur_program_handle_t_::CodeFormat CommonCodeFormat =
399+ phPrograms[0 ]->getCodeFormat ();
367400 for (uint32_t I = 0 ; I < count; I++) {
368401 std::shared_lock<ur_shared_mutex> Guard (phPrograms[I]->Mutex );
369402 Guards[I].swap (Guard);
@@ -374,6 +407,13 @@ ur_result_t urProgramLinkExp(
374407 return UR_RESULT_ERROR_INVALID_OPERATION;
375408 }
376409 }
410+
411+ // The L0 API has no way to represent mixed format modules,
412+ // even though it could be possible to implement linking
413+ // of mixed format modules.
414+ if (phPrograms[I]->getCodeFormat () != CommonCodeFormat) {
415+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
416+ }
377417 }
378418
379419 // Previous calls to urProgramCompile did not actually compile the SPIR-V.
@@ -406,7 +446,14 @@ ur_result_t urProgramLinkExp(
406446
407447 ZeStruct<ze_module_desc_t > ZeModuleDesc;
408448 ZeModuleDesc.pNext = &ZeExtModuleDesc;
409- ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
449+ switch (CommonCodeFormat) {
450+ case ur_program_handle_t_::CodeFormat::SPIRV:
451+ ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
452+ break ;
453+ default :
454+ ur::unreachable ();
455+ return UR_RESULT_ERROR_INVALID_PROGRAM;
456+ }
410457
411458 // This works around a bug in the Level Zero driver. When "ZE_DEBUG=-1",
412459 // the driver does validation of the API calls, and it expects
@@ -996,11 +1043,13 @@ ur_result_t urProgramSetSpecializationConstants(
9961043
9971044ur_program_handle_t_::ur_program_handle_t_ (state St,
9981045 ur_context_handle_t Context,
999- const void *Input, size_t Length)
1046+ const void *Input, size_t Length,
1047+ CodeFormat CodeFormat)
10001048 : Context{Context}, NativeProperties{nullptr }, OwnZeModule{true },
1001- AssociatedDevices (Context->getDevices ()), SpirvCode{new uint8_t [Length]},
1002- SpirvCodeLength{Length} {
1003- std::memcpy (SpirvCode.get (), Input, Length);
1049+ AssociatedDevices (Context->getDevices ()), ILCode{new uint8_t [Length]},
1050+ ILCodeLength{Length}, ILCodeFormat(CodeFormat) {
1051+ assert (isCodeFormatIL (CodeFormat));
1052+ std::memcpy (ILCode.get (), Input, Length);
10041053 // All devices have the program in IL state.
10051054 for (auto &Device : Context->getDevices ()) {
10061055 DeviceData &PerDevData = DeviceDataMap[Device->ZeDevice ];
0 commit comments