@@ -58,6 +58,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithIL(
5858 *Program // /< [out] pointer to handle of program object created.
5959) {
6060 std::ignore = Properties;
61+ UR_ASSERT (Context, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
62+ UR_ASSERT (IL && Program, UR_RESULT_ERROR_INVALID_NULL_POINTER);
6163 try {
6264 ur_program_handle_t_ *UrProgram =
6365 new ur_program_handle_t_ (ur_program_handle_t_::IL, Context, IL, Length);
@@ -82,8 +84,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
8284 ur_program_handle_t
8385 *Program // /< [out] pointer to handle of Program object created.
8486) {
85- std::ignore = Device;
86- std::ignore = Properties;
8787 // In OpenCL, clCreateProgramWithBinary() can be used to load any of the
8888 // following: "program executable", "compiled program", or "library of
8989 // compiled programs". In addition, the loaded program can be either
@@ -96,8 +96,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
9696 // information to distinguish the cases.
9797
9898 try {
99- ur_program_handle_t_ *UrProgram = new ur_program_handle_t_ (
100- ur_program_handle_t_::Native, Context, Binary, Size);
99+ ur_program_handle_t_ *UrProgram =
100+ new ur_program_handle_t_ (ur_program_handle_t_::Native, Context, Device,
101+ Properties, Binary, Size);
101102 *Program = reinterpret_cast <ur_program_handle_t >(UrProgram);
102103 } catch (const std::bad_alloc &) {
103104 return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
@@ -597,11 +598,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
597598 void **GlobalVariablePointerRet // /< [out] Returns the pointer to the global
598599 // /< variable if it is found in the program.
599600) {
600- std::ignore = Device;
601601 std::scoped_lock<ur_shared_mutex> lock (Program->Mutex );
602602
603+ ze_module_handle_t ZeModuleEntry{};
604+ ZeModuleEntry = Program->ZeModule ;
605+ if (!Program->ZeModuleMap .empty ()) {
606+ auto It = Program->ZeModuleMap .find (Device->ZeDevice );
607+ if (It != Program->ZeModuleMap .end ()) {
608+ ZeModuleEntry = It->second ;
609+ }
610+ }
611+
603612 ze_result_t ZeResult =
604- zeModuleGetGlobalPointer (Program-> ZeModule , GlobalVariableName,
613+ zeModuleGetGlobalPointer (ZeModuleEntry , GlobalVariableName,
605614 GlobalVariableSizeRet, GlobalVariablePointerRet);
606615
607616 if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) {
@@ -632,11 +641,28 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
632641 case UR_PROGRAM_INFO_CONTEXT:
633642 return ReturnValue (Program->Context );
634643 case UR_PROGRAM_INFO_NUM_DEVICES:
635- // TODO: return true number of devices this program exists for.
636- return ReturnValue (uint32_t {1 });
644+ if (!Program->ZeModuleMap .empty ())
645+ return ReturnValue (
646+ uint32_t {ur_cast<uint32_t >(Program->ZeModuleMap .size ())});
647+ else
648+ return ReturnValue (uint32_t {1 });
637649 case UR_PROGRAM_INFO_DEVICES:
638- // TODO: return all devices this program exists for.
639- return ReturnValue (Program->Context ->Devices [0 ]);
650+ if (!Program->ZeModuleMap .empty ()) {
651+ std::vector<ur_device_handle_t > devices;
652+ for (auto &ZeModulePair : Program->ZeModuleMap ) {
653+ auto It = Program->ZeModuleMap .find (ZeModulePair.first );
654+ if (It != Program->ZeModuleMap .end ()) {
655+ for (auto &Device : Program->Context ->Devices ) {
656+ if (Device->ZeDevice == ZeModulePair.first ) {
657+ devices.push_back (Device);
658+ }
659+ }
660+ }
661+ }
662+ return ReturnValue (devices.data (), devices.size ());
663+ } else {
664+ return ReturnValue (Program->Context ->Devices [0 ]);
665+ }
640666 case UR_PROGRAM_INFO_BINARY_SIZES: {
641667 std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
642668 size_t SzBinary;
@@ -645,8 +671,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
645671 Program->State == ur_program_handle_t_::Object) {
646672 SzBinary = Program->CodeLength ;
647673 } else if (Program->State == ur_program_handle_t_::Exe) {
648- ZE2UR_CALL (zeModuleGetNativeBinary,
649- (Program->ZeModule , &SzBinary, nullptr ));
674+ if (!Program->ZeModuleMap .empty ()) {
675+ std::vector<size_t > binarySizes;
676+ for (auto &ZeModulePair : Program->ZeModuleMap ) {
677+ size_t binarySize = 0 ;
678+ ZE2UR_CALL (zeModuleGetNativeBinary,
679+ (ZeModulePair.second , &binarySize, nullptr ));
680+ binarySizes.push_back (binarySize);
681+ }
682+ return ReturnValue (binarySizes.data (), binarySizes.size ());
683+ } else {
684+ ZE2UR_CALL (zeModuleGetNativeBinary,
685+ (Program->ZeModule , &SzBinary, nullptr ));
686+ return ReturnValue (SzBinary);
687+ }
650688 } else {
651689 return UR_RESULT_ERROR_INVALID_PROGRAM;
652690 }
@@ -655,38 +693,71 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
655693 }
656694 case UR_PROGRAM_INFO_BINARIES: {
657695 // The caller sets "ParamValue" to an array of pointers, one for each
658- // device. Since Level Zero supports only one device, there is only one
659- // pointer. If the pointer is NULL, we don't do anything. Otherwise, we
660- // copy the program's binary image to the buffer at that pointer.
661- uint8_t **PBinary = ur_cast<uint8_t **>(ProgramInfo);
662- if (!PBinary[0 ])
663- break ;
664-
696+ // device.
697+ uint8_t **PBinary = nullptr ;
698+ if (ProgramInfo) {
699+ PBinary = ur_cast<uint8_t **>(ProgramInfo);
700+ if (!PBinary[0 ]) {
701+ break ;
702+ }
703+ }
665704 std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
705+ // If the caller is using a Program which is IL, Native or an object, then
706+ // the program has not been built for multiple devices so a single IL is
707+ // returned.
666708 if (Program->State == ur_program_handle_t_::IL ||
667709 Program->State == ur_program_handle_t_::Native ||
668710 Program->State == ur_program_handle_t_::Object) {
669- std::memcpy (PBinary[0 ], Program->Code .get (), Program->CodeLength );
711+ if (PropSizeRet)
712+ *PropSizeRet = Program->CodeLength ;
713+ if (PBinary) {
714+ std::memcpy (PBinary[0 ], Program->Code .get (), Program->CodeLength );
715+ }
670716 } else if (Program->State == ur_program_handle_t_::Exe) {
717+ // If the caller is using a Program which is a built binary, then
718+ // the program returned will either be a single module if this is a native
719+ // binary or the native binary for each device will be returned.
671720 size_t SzBinary = 0 ;
672- ZE2UR_CALL (zeModuleGetNativeBinary,
673- (Program->ZeModule , &SzBinary, PBinary[0 ]));
721+ uint8_t *NativeBinaryPtr = nullptr ;
722+ if (PBinary) {
723+ NativeBinaryPtr = PBinary[0 ];
724+ }
725+ if (!Program->ZeModuleMap .empty ()) {
726+ uint32_t deviceIndex = 0 ;
727+ for (auto &ZeDeviceModule : Program->ZeModuleMap ) {
728+ size_t binarySize = 0 ;
729+ ZE2UR_CALL (
730+ zeModuleGetNativeBinary,
731+ (ZeDeviceModule.second , &binarySize, PBinary[deviceIndex++]));
732+ SzBinary += binarySize;
733+ }
734+ } else {
735+ ZE2UR_CALL (zeModuleGetNativeBinary,
736+ (Program->ZeModule , &SzBinary, NativeBinaryPtr));
737+ }
738+ if (PropSizeRet)
739+ *PropSizeRet = SzBinary;
674740 } else {
675741 return UR_RESULT_ERROR_INVALID_PROGRAM;
676742 }
677743 break ;
678744 }
679745 case UR_PROGRAM_INFO_NUM_KERNELS: {
680746 std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
681- uint32_t NumKernels;
747+ uint32_t NumKernels = 0 ;
682748 if (Program->State == ur_program_handle_t_::IL ||
683749 Program->State == ur_program_handle_t_::Native ||
684750 Program->State == ur_program_handle_t_::Object) {
685751 return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
686752 } else if (Program->State == ur_program_handle_t_::Exe) {
687- NumKernels = 0 ;
688- ZE2UR_CALL (zeModuleGetKernelNames,
689- (Program->ZeModule , &NumKernels, nullptr ));
753+ if (!Program->ZeModuleMap .empty ()) {
754+ ZE2UR_CALL (
755+ zeModuleGetKernelNames,
756+ (Program->ZeModuleMap .begin ()->second , &NumKernels, nullptr ));
757+ } else {
758+ ZE2UR_CALL (zeModuleGetKernelNames,
759+ (Program->ZeModule , &NumKernels, nullptr ));
760+ }
690761 } else {
691762 return UR_RESULT_ERROR_INVALID_PROGRAM;
692763 }
@@ -702,11 +773,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
702773 return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
703774 } else if (Program->State == ur_program_handle_t_::Exe) {
704775 uint32_t Count = 0 ;
705- ZE2UR_CALL (zeModuleGetKernelNames,
706- (Program->ZeModule , &Count, nullptr ));
707- std::unique_ptr<const char *[]> PNames (new const char *[Count]);
708- ZE2UR_CALL (zeModuleGetKernelNames,
709- (Program->ZeModule , &Count, PNames.get ()));
776+ std::unique_ptr<const char *[]> PNames;
777+ if (!Program->ZeModuleMap .empty ()) {
778+ ZE2UR_CALL (zeModuleGetKernelNames,
779+ (Program->ZeModuleMap .begin ()->second , &Count, nullptr ));
780+ PNames = std::make_unique<const char *[]>(Count);
781+ ZE2UR_CALL (
782+ zeModuleGetKernelNames,
783+ (Program->ZeModuleMap .begin ()->second , &Count, PNames.get ()));
784+ } else {
785+ ZE2UR_CALL (zeModuleGetKernelNames,
786+ (Program->ZeModule , &Count, nullptr ));
787+ PNames = std::make_unique<const char *[]>(Count);
788+ ZE2UR_CALL (zeModuleGetKernelNames,
789+ (Program->ZeModule , &Count, PNames.get ()));
790+ }
710791 for (uint32_t I = 0 ; I < Count; ++I) {
711792 PINames += (I > 0 ? " ;" : " " );
712793 PINames += PNames[I];
@@ -720,8 +801,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
720801 } catch (...) {
721802 return UR_RESULT_ERROR_UNKNOWN;
722803 }
804+ case UR_PROGRAM_INFO_SOURCE:
805+ return ReturnValue (Program->Code .get ());
723806 default :
724- die ( " urProgramGetInfo: not implemented " ) ;
807+ return UR_RESULT_ERROR_INVALID_ENUMERATION ;
725808 }
726809
727810 return UR_RESULT_SUCCESS;
@@ -761,6 +844,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetBuildInfo(
761844 // return for programs that were built outside and registered
762845 // with urProgramRegister?
763846 return ReturnValue (" " );
847+ } else if (PropName == UR_PROGRAM_BUILD_INFO_STATUS) {
848+ return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
764849 } else if (PropName == UR_PROGRAM_BUILD_INFO_LOG) {
765850 // Check first to see if the plugin code recorded an error message.
766851 if (!Program->ErrorMessage .empty ()) {
@@ -852,6 +937,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithNativeHandle(
852937 // /< program object created.
853938) {
854939 std::ignore = Properties;
940+ UR_ASSERT (Context && NativeProgram, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
941+ UR_ASSERT (Program, UR_RESULT_ERROR_INVALID_NULL_POINTER);
855942 auto ZeModule = ur_cast<ze_module_handle_t >(NativeProgram);
856943
857944 // We assume here that programs created from a native handle always
0 commit comments