@@ -599,11 +599,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
599599 void **GlobalVariablePointerRet // /< [out] Returns the pointer to the global
600600 // /< variable if it is found in the program.
601601) {
602- std::ignore = Device;
603602 std::scoped_lock<ur_shared_mutex> lock (Program->Mutex );
604603
604+ ze_module_handle_t ZeModuleEntry{};
605+ ZeModuleEntry = Program->ZeModule ;
606+ if (!Program->ZeModuleMap .empty ()) {
607+ auto It = Program->ZeModuleMap .find (Device->ZeDevice );
608+ if (It != Program->ZeModuleMap .end ()) {
609+ ZeModuleEntry = It->second ;
610+ }
611+ }
612+
605613 ze_result_t ZeResult =
606- zeModuleGetGlobalPointer (Program-> ZeModule , GlobalVariableName,
614+ zeModuleGetGlobalPointer (ZeModuleEntry , GlobalVariableName,
607615 GlobalVariableSizeRet, GlobalVariablePointerRet);
608616
609617 if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) {
@@ -634,11 +642,28 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
634642 case UR_PROGRAM_INFO_CONTEXT:
635643 return ReturnValue (Program->Context );
636644 case UR_PROGRAM_INFO_NUM_DEVICES:
637- // TODO: return true number of devices this program exists for.
638- return ReturnValue (uint32_t {1 });
645+ if (!Program->ZeModuleMap .empty ())
646+ return ReturnValue (
647+ uint32_t {ur_cast<uint32_t >(Program->ZeModuleMap .size ())});
648+ else
649+ return ReturnValue (uint32_t {1 });
639650 case UR_PROGRAM_INFO_DEVICES:
640- // TODO: return all devices this program exists for.
641- return ReturnValue (Program->Context ->Devices [0 ]);
651+ if (!Program->ZeModuleMap .empty ()) {
652+ std::vector<ur_device_handle_t > devices;
653+ for (auto &ZeModulePair : Program->ZeModuleMap ) {
654+ auto It = Program->ZeModuleMap .find (ZeModulePair.first );
655+ if (It != Program->ZeModuleMap .end ()) {
656+ for (auto &Device : Program->Context ->Devices ) {
657+ if (Device->ZeDevice == ZeModulePair.first ) {
658+ devices.push_back (Device);
659+ }
660+ }
661+ }
662+ }
663+ return ReturnValue (devices);
664+ } else {
665+ return ReturnValue (Program->Context ->Devices [0 ]);
666+ }
642667 case UR_PROGRAM_INFO_BINARY_SIZES: {
643668 std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
644669 size_t SzBinary;
@@ -647,8 +672,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
647672 Program->State == ur_program_handle_t_::Object) {
648673 SzBinary = Program->CodeLength ;
649674 } else if (Program->State == ur_program_handle_t_::Exe) {
650- ZE2UR_CALL (zeModuleGetNativeBinary,
651- (Program->ZeModule , &SzBinary, nullptr ));
675+ if (!Program->ZeModuleMap .empty ()) {
676+ std::vector<size_t > binarySizes;
677+ for (auto &ZeModulePair : Program->ZeModuleMap ) {
678+ size_t binarySize = 0 ;
679+ ZE2UR_CALL (zeModuleGetNativeBinary,
680+ (ZeModulePair.second , &binarySize, nullptr ));
681+ binarySizes.push_back (binarySize);
682+ }
683+ return ReturnValue (binarySizes);
684+ } else {
685+ ZE2UR_CALL (zeModuleGetNativeBinary,
686+ (Program->ZeModule , &SzBinary, nullptr ));
687+ return ReturnValue (SzBinary);
688+ }
652689 } else {
653690 return UR_RESULT_ERROR_INVALID_PROGRAM;
654691 }
@@ -657,9 +694,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
657694 }
658695 case UR_PROGRAM_INFO_BINARIES: {
659696 // The caller sets "ParamValue" to an array of pointers, one for each
660- // device. Since Level Zero supports only one device, there is only one
661- // pointer. If the pointer is NULL, we don't do anything. Otherwise, we
662- // copy the program's binary image to the buffer at that pointer.
697+ // device.
663698 uint8_t **PBinary = nullptr ;
664699 if (ProgramInfo) {
665700 PBinary = ur_cast<uint8_t **>(ProgramInfo);
@@ -668,6 +703,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
668703 }
669704 }
670705 std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
706+ // If the caller is using a Program which is IL, Native or an object, then
707+ // the program has not been built for multiple devices so a single IL is
708+ // returned.
671709 if (Program->State == ur_program_handle_t_::IL ||
672710 Program->State == ur_program_handle_t_::Native ||
673711 Program->State == ur_program_handle_t_::Object) {
@@ -677,13 +715,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
677715 std::memcpy (PBinary[0 ], Program->Code .get (), Program->CodeLength );
678716 }
679717 } else if (Program->State == ur_program_handle_t_::Exe) {
718+ // If the caller is using a Program which is a built binary, then
719+ // the program returned will either be a single module if this is a native
720+ // binary or the native binary for each device will be returned.
680721 size_t SzBinary = 0 ;
681722 uint8_t *NativeBinaryPtr = nullptr ;
682723 if (PBinary) {
683724 NativeBinaryPtr = PBinary[0 ];
684725 }
685- ZE2UR_CALL (zeModuleGetNativeBinary,
686- (Program->ZeModule , &SzBinary, NativeBinaryPtr));
726+ if (!Program->ZeModuleMap .empty ()) {
727+ uint32_t deviceIndex = 0 ;
728+ for (auto &ZeDeviceModule : Program->ZeModuleMap ) {
729+ size_t binarySize = 0 ;
730+ ZE2UR_CALL (
731+ zeModuleGetNativeBinary,
732+ (ZeDeviceModule.second , &binarySize, PBinary[deviceIndex++]));
733+ SzBinary += binarySize;
734+ }
735+ } else {
736+ ZE2UR_CALL (zeModuleGetNativeBinary,
737+ (Program->ZeModule , &SzBinary, NativeBinaryPtr));
738+ }
687739 if (PropSizeRet)
688740 *PropSizeRet = SzBinary;
689741 } else {
@@ -693,15 +745,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
693745 }
694746 case UR_PROGRAM_INFO_NUM_KERNELS: {
695747 std::shared_lock<ur_shared_mutex> Guard (Program->Mutex );
696- uint32_t NumKernels;
748+ uint32_t NumKernels = 0 ;
697749 if (Program->State == ur_program_handle_t_::IL ||
698750 Program->State == ur_program_handle_t_::Native ||
699751 Program->State == ur_program_handle_t_::Object) {
700752 return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
701753 } else if (Program->State == ur_program_handle_t_::Exe) {
702- NumKernels = 0 ;
703- ZE2UR_CALL (zeModuleGetKernelNames,
704- (Program->ZeModule , &NumKernels, nullptr ));
754+ if (!Program->ZeModuleMap .empty ()) {
755+ ZE2UR_CALL (
756+ zeModuleGetKernelNames,
757+ (Program->ZeModuleMap .begin ()->second , &NumKernels, nullptr ));
758+ } else {
759+ ZE2UR_CALL (zeModuleGetKernelNames,
760+ (Program->ZeModule , &NumKernels, nullptr ));
761+ }
705762 } else {
706763 return UR_RESULT_ERROR_INVALID_PROGRAM;
707764 }
@@ -717,11 +774,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo(
717774 return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
718775 } else if (Program->State == ur_program_handle_t_::Exe) {
719776 uint32_t Count = 0 ;
720- ZE2UR_CALL (zeModuleGetKernelNames,
721- (Program->ZeModule , &Count, nullptr ));
722- std::unique_ptr<const char *[]> PNames (new const char *[Count]);
723- ZE2UR_CALL (zeModuleGetKernelNames,
724- (Program->ZeModule , &Count, PNames.get ()));
777+ std::unique_ptr<const char *[]> PNames;
778+ if (!Program->ZeModuleMap .empty ()) {
779+ ZE2UR_CALL (zeModuleGetKernelNames,
780+ (Program->ZeModuleMap .begin ()->second , &Count, nullptr ));
781+ PNames = std::make_unique<const char *[]>(Count);
782+ ZE2UR_CALL (
783+ zeModuleGetKernelNames,
784+ (Program->ZeModuleMap .begin ()->second , &Count, PNames.get ()));
785+ } else {
786+ ZE2UR_CALL (zeModuleGetKernelNames,
787+ (Program->ZeModule , &Count, nullptr ));
788+ PNames = std::make_unique<const char *[]>(Count);
789+ ZE2UR_CALL (zeModuleGetKernelNames,
790+ (Program->ZeModule , &Count, PNames.get ()));
791+ }
725792 for (uint32_t I = 0 ; I < Count; ++I) {
726793 PINames += (I > 0 ? " ;" : " " );
727794 PINames += PNames[I];
0 commit comments