1010#include " shared/source/built_ins/built_ins.h"
1111#include " shared/source/built_ins/sip.h"
1212#include " shared/source/execution_environment/root_device_environment.h"
13+ #include " shared/source/helpers/affinity_mask.h"
1314#include " shared/source/helpers/hw_helper.h"
1415#include " shared/source/memory_manager/memory_manager.h"
1516#include " shared/source/memory_manager/os_agnostic_memory_manager.h"
@@ -91,33 +92,31 @@ void ExecutionEnvironment::parseAffinityMask() {
9192 return ;
9293 }
9394
94- std::vector<std::vector<bool >> affinityMaskBitSet (rootDeviceEnvironments.size ());
95- for (uint32_t i = 0 ; i < affinityMaskBitSet.size (); i++) {
96- auto hwInfo = rootDeviceEnvironments[i]->getHardwareInfo ();
97- affinityMaskBitSet[i].resize (HwHelper::getSubDevicesCount (hwInfo));
98- }
95+ const uint32_t numRootDevices = static_cast <uint32_t >(rootDeviceEnvironments.size ());
96+
97+ std::vector<AffinityMaskHelper> affinityMaskHelper (numRootDevices);
9998
10099 size_t pos = 0 ;
101100 while (pos < affinityMaskString.size ()) {
102101 size_t posNextDot = affinityMaskString.find_first_of (" ." , pos);
103102 size_t posNextComma = affinityMaskString.find_first_of (" ," , pos);
104103 std::string rootDeviceString = affinityMaskString.substr (pos, std::min (posNextDot, posNextComma) - pos);
105104 uint32_t rootDeviceIndex = static_cast <uint32_t >(std::stoul (rootDeviceString, nullptr , 0 ));
106- if (rootDeviceIndex < rootDeviceEnvironments.size ()) {
105+ if (rootDeviceIndex < numRootDevices) {
106+ auto hwInfo = rootDeviceEnvironments[rootDeviceIndex]->getHardwareInfo ();
107+ auto subDevicesCount = HwHelper::getSubDevicesCount (hwInfo);
108+
107109 pos += rootDeviceString.size ();
108110 if (posNextDot != std::string::npos &&
109111 affinityMaskString.at (pos) == ' .' && posNextDot < posNextComma) {
110112 pos++;
111113 std::string subDeviceString = affinityMaskString.substr (pos, posNextComma - pos);
112114 uint32_t subDeviceIndex = static_cast <uint32_t >(std::stoul (subDeviceString, nullptr , 0 ));
113- auto hwInfo = rootDeviceEnvironments[rootDeviceIndex]->getHardwareInfo ();
114- if (subDeviceIndex < HwHelper::getSubDevicesCount (hwInfo)) {
115- affinityMaskBitSet[rootDeviceIndex][subDeviceIndex] = true ;
115+ if (subDeviceIndex < subDevicesCount) {
116+ affinityMaskHelper[rootDeviceIndex].enableGenericSubDevice (subDeviceIndex);
116117 }
117118 } else {
118- std::fill (affinityMaskBitSet[rootDeviceIndex].begin (),
119- affinityMaskBitSet[rootDeviceIndex].end (),
120- true );
119+ affinityMaskHelper[rootDeviceIndex].enableAllGenericSubDevices (subDevicesCount);
121120 }
122121 }
123122 if (posNextComma == std::string::npos) {
@@ -126,31 +125,13 @@ void ExecutionEnvironment::parseAffinityMask() {
126125 pos = posNextComma + 1 ;
127126 }
128127
129- uint32_t offset = 0 ;
130- uint32_t affinityMask = 0 ;
131- for (uint32_t i = 0 ; i < affinityMaskBitSet.size (); i++) {
132- for (uint32_t j = 0 ; j < affinityMaskBitSet[i].size (); j++) {
133- if (affinityMaskBitSet[i][j] == true ) {
134- affinityMask |= (1UL << offset);
135- }
136- offset++;
137- }
138- }
139-
140- uint32_t currentMaskOffset = 0 ;
141128 std::vector<std::unique_ptr<RootDeviceEnvironment>> filteredEnvironments;
142- for (size_t i = 0u ; i < this ->rootDeviceEnvironments .size (); i++) {
143- auto hwInfo = rootDeviceEnvironments[i]->getHardwareInfo ();
144-
145- uint32_t currentDeviceMask = (affinityMask >> currentMaskOffset) & ((1UL << HwHelper::getSubDevicesCount (hwInfo)) - 1 );
146- bool isDeviceExposed = currentDeviceMask > 0 ;
147-
148- currentMaskOffset += HwHelper::getSubDevicesCount (hwInfo);
149- if (!isDeviceExposed) {
129+ for (uint32_t i = 0u ; i < numRootDevices; i++) {
130+ if (!affinityMaskHelper[i].isDeviceEnabled ()) {
150131 continue ;
151132 }
152133
153- rootDeviceEnvironments[i]->deviceAffinityMask = currentDeviceMask ;
134+ rootDeviceEnvironments[i]->deviceAffinityMask = affinityMaskHelper[i] ;
154135 filteredEnvironments.emplace_back (rootDeviceEnvironments[i].release ());
155136 }
156137
0 commit comments