Skip to content

Commit adb2421

Browse files
authored
[Offload] Refactor device information queries to use new tagging (#147318)
Instead using strings to look up device information (which is brittle and slow), use the new tags that the plugins specify when building the nodes.
1 parent 6b92a3b commit adb2421

File tree

2 files changed

+53
-76
lines changed

2 files changed

+53
-76
lines changed

offload/liboffload/src/Helpers.hpp

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,23 +75,16 @@ class InfoWriter {
7575
InfoWriter(InfoWriter &) = delete;
7676
~InfoWriter() = default;
7777

78-
template <typename T> llvm::Error write(llvm::Expected<T> &&Val) {
79-
if (Val)
80-
return getInfo(Size, Target, SizeRet, *Val);
81-
return Val.takeError();
78+
template <typename T> llvm::Error write(T Val) {
79+
return getInfo(Size, Target, SizeRet, Val);
8280
}
8381

84-
template <typename T>
85-
llvm::Error writeArray(llvm::Expected<T> &&Val, size_t Elems) {
86-
if (Val)
87-
return getInfoArray(Elems, Size, Target, SizeRet, *Val);
88-
return Val.takeError();
82+
template <typename T> llvm::Error writeArray(T Val, size_t Elems) {
83+
return getInfoArray(Elems, Size, Target, SizeRet, Val);
8984
}
9085

91-
llvm::Error writeString(llvm::Expected<llvm::StringRef> &&Val) {
92-
if (Val)
93-
return getInfoString(Size, Target, SizeRet, *Val);
94-
return Val.takeError();
86+
llvm::Error writeString(llvm::StringRef Val) {
87+
return getInfoString(Size, Target, SizeRet, Val);
9588
}
9689

9790
private:

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 47 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -299,78 +299,62 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
299299
return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
300300
};
301301

302-
// Find the info if it exists under any of the given names
303-
auto getInfoString =
304-
[&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
305-
for (auto &Name : Names) {
306-
if (auto Entry = Device->Info.get(Name)) {
307-
if (!std::holds_alternative<std::string>((*Entry)->Value))
308-
return makeError(ErrorCode::BACKEND_FAILURE,
309-
"plugin returned incorrect type");
310-
return std::get<std::string>((*Entry)->Value).c_str();
311-
}
312-
}
313-
314-
return makeError(ErrorCode::UNIMPLEMENTED,
315-
"plugin did not provide a response for this information");
316-
};
317-
318-
auto getInfoXyz =
319-
[&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> {
320-
for (auto &Name : Names) {
321-
if (auto Entry = Device->Info.get(Name)) {
322-
auto Node = *Entry;
323-
ol_dimensions_t Out{0, 0, 0};
324-
325-
auto getField = [&](StringRef Name, uint32_t &Dest) {
326-
if (auto F = Node->get(Name)) {
327-
if (!std::holds_alternative<size_t>((*F)->Value))
328-
return makeError(
329-
ErrorCode::BACKEND_FAILURE,
330-
"plugin returned incorrect type for dimensions element");
331-
Dest = std::get<size_t>((*F)->Value);
332-
} else
333-
return makeError(ErrorCode::BACKEND_FAILURE,
334-
"plugin didn't provide all values for dimensions");
335-
return Plugin::success();
336-
};
337-
338-
if (auto Res = getField("x", Out.x))
339-
return Res;
340-
if (auto Res = getField("y", Out.y))
341-
return Res;
342-
if (auto Res = getField("z", Out.z))
343-
return Res;
344-
345-
return Out;
346-
}
347-
}
302+
// These are not implemented by the plugin interface
303+
if (PropName == OL_DEVICE_INFO_PLATFORM)
304+
return Info.write<void *>(Device->Platform);
305+
if (PropName == OL_DEVICE_INFO_TYPE)
306+
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
307+
if (PropName >= OL_DEVICE_INFO_LAST)
308+
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
309+
"getDeviceInfo enum '%i' is invalid", PropName);
348310

311+
auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
312+
if (!EntryOpt)
349313
return makeError(ErrorCode::UNIMPLEMENTED,
350314
"plugin did not provide a response for this information");
351-
};
315+
auto Entry = *EntryOpt;
352316

353317
switch (PropName) {
354-
case OL_DEVICE_INFO_PLATFORM:
355-
return Info.write<void *>(Device->Platform);
356-
case OL_DEVICE_INFO_TYPE:
357-
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
358318
case OL_DEVICE_INFO_NAME:
359-
return Info.writeString(getInfoString({"Device Name"}));
360319
case OL_DEVICE_INFO_VENDOR:
361-
return Info.writeString(getInfoString({"Vendor Name"}));
362-
case OL_DEVICE_INFO_DRIVER_VERSION:
363-
return Info.writeString(
364-
getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
365-
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
366-
return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
367-
"Maximum Block Dimensions" /*CUDA*/}));
368-
default:
369-
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
370-
"getDeviceInfo enum '%i' is invalid", PropName);
320+
case OL_DEVICE_INFO_DRIVER_VERSION: {
321+
// String values
322+
if (!std::holds_alternative<std::string>(Entry->Value))
323+
return makeError(ErrorCode::BACKEND_FAILURE,
324+
"plugin returned incorrect type");
325+
return Info.writeString(std::get<std::string>(Entry->Value).c_str());
371326
}
372327

373-
return Error::success();
328+
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
329+
// {x, y, z} triples
330+
ol_dimensions_t Out{0, 0, 0};
331+
332+
auto getField = [&](StringRef Name, uint32_t &Dest) {
333+
if (auto F = Entry->get(Name)) {
334+
if (!std::holds_alternative<size_t>((*F)->Value))
335+
return makeError(
336+
ErrorCode::BACKEND_FAILURE,
337+
"plugin returned incorrect type for dimensions element");
338+
Dest = std::get<size_t>((*F)->Value);
339+
} else
340+
return makeError(ErrorCode::BACKEND_FAILURE,
341+
"plugin didn't provide all values for dimensions");
342+
return Plugin::success();
343+
};
344+
345+
if (auto Res = getField("x", Out.x))
346+
return Res;
347+
if (auto Res = getField("y", Out.y))
348+
return Res;
349+
if (auto Res = getField("z", Out.z))
350+
return Res;
351+
352+
return Info.write(Out);
353+
}
354+
355+
default:
356+
llvm_unreachable("Unimplemented device info");
357+
}
374358
}
375359

376360
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,

0 commit comments

Comments
 (0)