Skip to content

Commit a591b0e

Browse files
committed
Move DeviceRange to PluginInterface and remove internal Device list
1 parent 19e888e commit a591b0e

File tree

3 files changed

+64
-78
lines changed

3 files changed

+64
-78
lines changed

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1266,7 +1266,7 @@ struct GenericPluginTy {
12661266
virtual GenericGlobalHandlerTy *createGlobalHandler() = 0;
12671267

12681268
/// Get the reference to the device with a certain device id.
1269-
GenericDeviceTy &getDevice(int32_t DeviceId) {
1269+
GenericDeviceTy &getDevice(int32_t DeviceId) const {
12701270
assert(isValidDeviceId(DeviceId) && "Invalid device id");
12711271
assert(Devices[DeviceId] && "Device is uninitialized");
12721272

@@ -1527,6 +1527,23 @@ struct GenericPluginTy {
15271527
/// object and return immediately.
15281528
int32_t async_barrier(omp_interop_val_t *Interop);
15291529

1530+
struct DevicesRangeTy {
1531+
using iterator = llvm::SmallVector<GenericDeviceTy *>::iterator;
1532+
1533+
iterator BeginIt;
1534+
iterator EndIt;
1535+
1536+
DevicesRangeTy(iterator BeginIt, iterator EndIt)
1537+
: BeginIt(BeginIt), EndIt(EndIt) {}
1538+
1539+
auto &begin() { return BeginIt; }
1540+
auto &end() { return EndIt; }
1541+
};
1542+
1543+
DevicesRangeTy getDevicesRange() {
1544+
return DevicesRangeTy(Devices.begin(), Devices.end());
1545+
}
1546+
15301547
private:
15311548
/// Indicates if the platform runtime has been fully initialized.
15321549
bool Initialized = false;

offload/plugins-nextgen/level_zero/include/L0Plugin.h

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,16 @@ namespace llvm::omp::target::plugin {
2626
/// Class implementing the LevelZero specific functionalities of the plugin.
2727
class LevelZeroPluginTy final : public GenericPluginTy {
2828
private:
29-
/// Number of devices available including subdevices
30-
uint32_t NumDevices = 0;
29+
struct DeviceInfoTy {
30+
L0DeviceIdTy Id;
31+
L0ContextTy *Driver;
32+
bool isRoot() const { return Id.SubId < 0 && Id.CCSId < 0; }
33+
};
34+
llvm::SmallVector<DeviceInfoTy> DetectedDevices;
3135

3236
/// Context (and Driver) specific data
3337
std::list<L0ContextTy> ContextList;
3438

35-
/// L0 device used by each OpenMP device
36-
using DeviceContainerTy = llvm::SmallVector<L0DeviceTy *>;
37-
DeviceContainerTy L0Devices;
38-
3939
// Table containing per-thread information using TLS
4040
L0ThreadTblTy ThreadTLSTable;
4141
// Table containing per-thread information for each device using TLS
@@ -51,6 +51,10 @@ class LevelZeroPluginTy final : public GenericPluginTy {
5151

5252
auto &getTLS() { return ThreadTLSTable.get(); }
5353

54+
/// Find L0 devices and initialize device properties.
55+
/// Returns number of devices reported to omptarget.
56+
Expected<int32_t> findDevices();
57+
5458
public:
5559
LevelZeroPluginTy() : GenericPluginTy(getTripleArch()) {}
5660
virtual ~LevelZeroPluginTy() {}
@@ -62,35 +66,10 @@ class LevelZeroPluginTy final : public GenericPluginTy {
6266

6367
static const auto &getOptions() { return Options; }
6468

65-
struct DevicesRangeTy {
66-
using iterator = DeviceContainerTy::iterator;
67-
68-
iterator BeginIt;
69-
iterator EndIt;
70-
71-
DevicesRangeTy(iterator BeginIt, iterator EndIt)
72-
: BeginIt(BeginIt), EndIt(EndIt) {}
73-
74-
auto &begin() { return BeginIt; }
75-
auto &end() { return EndIt; }
76-
};
77-
78-
auto getDevicesRange() {
79-
return DevicesRangeTy(L0Devices.begin(), L0Devices.end());
80-
}
81-
82-
/// Find L0 devices and initialize device properties.
83-
/// Returns number of devices reported to omptarget.
84-
Expected<int32_t> findDevices();
85-
8669
L0DeviceTy &getDeviceFromId(int32_t DeviceId) const {
87-
assert("Invalid device ID" && DeviceId >= 0 &&
88-
DeviceId < static_cast<int32_t>(L0Devices.size()));
89-
return *L0Devices[DeviceId];
70+
return static_cast<L0DeviceTy &>(getDevice(DeviceId));
9071
}
9172

92-
uint32_t getNumRootDevices() const { return NumDevices; }
93-
9473
AsyncQueueTy *getAsyncQueue() {
9574
auto *Queue = getTLS().getAsyncQueue();
9675
if (!Queue)

offload/plugins-nextgen/level_zero/src/L0Plugin.cpp

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -86,57 +86,29 @@ Expected<int32_t> LevelZeroPluginTy::findDevices() {
8686
return A.IsDiscrete;
8787
});
8888

89-
struct DeviceInfoTy {
90-
L0DeviceIdTy Id;
91-
L0ContextTy *Driver;
92-
bool isRoot() const { return Id.SubId < 0 && Id.CCSId < 0; }
93-
};
94-
95-
llvm::SmallVector<DeviceInfoTy> DevicesToAdd;
96-
9789
for (size_t RootId = 0; RootId < RootDevices.size(); RootId++) {
9890
const auto zeDevice = RootDevices[RootId].zeDevice;
9991
auto *RootDriver = RootDevices[RootId].Driver;
100-
DevicesToAdd.push_back(
101-
{{zeDevice, static_cast<int32_t>(RootId), -1, -1}, RootDriver});
102-
}
103-
NumDevices = DevicesToAdd.size();
104-
auto DeviceId = 0;
105-
for (auto &DeviceInfo : DevicesToAdd) {
106-
auto RootId = DeviceInfo.Id.RootId;
107-
auto SubId = DeviceInfo.Id.SubId;
108-
auto CCSId = DeviceInfo.Id.CCSId;
109-
auto zeDevice = DeviceInfo.Id.zeId;
110-
auto *Driver = DeviceInfo.Driver;
111-
112-
std::string IdStr = std::to_string(RootId) +
113-
(SubId < 0 ? "" : "." + std::to_string(SubId)) +
114-
(CCSId < 0 ? "" : "." + std::to_string(CCSId));
115-
116-
L0Devices.push_back(new L0DeviceTy(*this, DeviceId, getNumRootDevices(),
117-
zeDevice, *Driver, std::move(IdStr),
118-
CCSId < 0 ? 0 : CCSId /* ComputeIndex */
119-
));
120-
DeviceId++;
92+
DetectedDevices.push_back(DeviceInfoTy{
93+
{zeDevice, static_cast<int32_t>(RootId), -1, -1}, RootDriver});
12194
}
95+
int32_t NumDevices = DetectedDevices.size();
12296

123-
DP("Found %" PRIu32 " root devices, %" PRIu32 " total devices.\n",
124-
getNumRootDevices(), NumDevices);
97+
DP("Found %" PRIu32 " devices.\n", NumDevices);
12598
DP("List of devices (DeviceID[.SubID[.CCSID]])\n");
126-
for (auto &l0Device : L0Devices) {
127-
DP("-- %s\n", l0Device->getZeIdCStr());
128-
(void)l0Device; // silence warning
99+
for (auto &DeviceInfo : DetectedDevices) {
100+
(void)DeviceInfo; // to avoid unused variable warning in non-debug builds
101+
DP("-- Device %" PRIu32 "%s%s (zeDevice=%p) from Driver %p\n",
102+
DeviceInfo.Id.RootId,
103+
(DeviceInfo.Id.SubId < 0
104+
? ""
105+
: ("." + std::to_string(DeviceInfo.Id.SubId)).c_str()),
106+
(DeviceInfo.Id.CCSId < 0
107+
? ""
108+
: ("." + std::to_string(DeviceInfo.Id.CCSId)).c_str()),
109+
DPxPTR(DeviceInfo.Id.zeId), DPxPTR(DeviceInfo.Id.Driver));
129110
}
130-
131-
if (getDebugLevel() > 0) {
132-
DP("Root Device Information\n");
133-
for (uint32_t I = 0; I < getNumRootDevices(); I++) {
134-
auto &l0Device = getDeviceFromId(I);
135-
l0Device.reportDeviceInfo();
136-
}
137-
}
138-
139-
return getNumRootDevices();
111+
return NumDevices;
140112
}
141113

142114
Expected<int32_t> LevelZeroPluginTy::initImpl() {
@@ -163,7 +135,25 @@ Error LevelZeroPluginTy::deinitImpl() {
163135
GenericDeviceTy *LevelZeroPluginTy::createDevice(GenericPluginTy &Plugin,
164136
int32_t DeviceId,
165137
int32_t NumDevices) {
166-
return &getDeviceFromId(DeviceId);
138+
auto &DeviceInfo = DetectedDevices[DeviceId];
139+
auto RootId = DeviceInfo.Id.RootId;
140+
auto SubId = DeviceInfo.Id.SubId;
141+
auto CCSId = DeviceInfo.Id.CCSId;
142+
auto zeDevice = DeviceInfo.Id.zeId;
143+
auto *zeDriver = DeviceInfo.Driver;
144+
145+
std::string IdStr = std::to_string(RootId) +
146+
(SubId < 0 ? "" : "." + std::to_string(SubId)) +
147+
(CCSId < 0 ? "" : "." + std::to_string(CCSId));
148+
149+
auto *NewDevice = new L0DeviceTy(
150+
static_cast<LevelZeroPluginTy &>(Plugin), DeviceId, NumDevices, zeDevice,
151+
*zeDriver, std::move(IdStr), CCSId < 0 ? 0 : CCSId /* ComputeIndex */);
152+
if (NewDevice && getDebugLevel() > 0) {
153+
DP("Device %" PRIi32 " information\n", DeviceId);
154+
NewDevice->reportDeviceInfo();
155+
}
156+
return NewDevice;
167157
}
168158

169159
GenericGlobalHandlerTy *LevelZeroPluginTy::createGlobalHandler() {

0 commit comments

Comments
 (0)