Skip to content

Commit 7f341eb

Browse files
authored
[AMD] Check libamdhip version during initialization (#7501)
This commit introduces a runtime driver version check that ensures the major version meets a minimum required version, and returns a descriptive error when the runtime driver library is outdated (unsupported or incompatible).
1 parent 1be3fa1 commit 7f341eb

File tree

1 file changed

+86
-19
lines changed

1 file changed

+86
-19
lines changed

third_party/amd/backend/driver.c

Lines changed: 86 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,34 @@ static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
2727
FOR_EACH_ERR_FN(hipFuncGetAttribute, int *, hipFunction_attribute attr, \
2828
hipFunction_t function)
2929

30+
// HIP driver version format: HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR *
31+
// 100000 + HIP_VERSION_PATCH.
32+
#define TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version) ((version) / 10000000)
33+
#define TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version) \
34+
(((version) % 10000000) / 100000)
35+
#define TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version) ((version) % 100000)
36+
#define TRITON_HIP_DRIVER_REQ_MAJOR_VERSION (HIP_VERSION_MAJOR)
37+
38+
// #define TRITON_HIP_DRIVER_DBG_VERSION
39+
#ifdef TRITON_HIP_DRIVER_DBG_VERSION
40+
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
41+
do { \
42+
snprintf(msgBuff, sizeof(msgBuff), "libamdhip64 version is: %d.%d.%d", \
43+
TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version), \
44+
TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version), \
45+
TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version)); \
46+
printf("%s\n", msgBuff); \
47+
} while (0);
48+
#else
49+
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
50+
do { \
51+
(void)msgBuff; \
52+
(void)(version); \
53+
} while (0);
54+
#endif
55+
56+
#define TRITON_HIP_MSG_BUFF_SIZE (1024U)
57+
3058
// The HIP symbol table for holding resolved dynamic library symbols.
3159
struct HIPSymbolTable {
3260
#define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \
@@ -39,36 +67,75 @@ struct HIPSymbolTable {
3967

4068
static struct HIPSymbolTable hipSymbolTable;
4169

42-
bool initSymbolTable() {
43-
// Use the HIP runtime library loaded into the existing process if it exists.
44-
void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
45-
if (lib) {
46-
// printf("[triton] chosen loaded libamdhip64.so in the process\n");
70+
static int checkDriverVersion(void *lib) {
71+
int hipVersion = -1;
72+
const char *error = NULL;
73+
typedef hipError_t (*hipDriverGetVersion_fn)(int *driverVersion);
74+
hipDriverGetVersion_fn hipDriverGetVersion;
75+
dlerror(); // Clear existing errors
76+
hipDriverGetVersion =
77+
(hipDriverGetVersion_fn)dlsym(lib, "hipDriverGetVersion");
78+
error = dlerror();
79+
if (error) {
80+
PyErr_SetString(PyExc_RuntimeError,
81+
"cannot query 'hipDriverGetVersion' from libamdhip64.so");
82+
dlclose(lib);
83+
return -1;
4784
}
4885

49-
// Otherwise, go through the list of search paths to dlopen the first HIP
50-
// driver library.
51-
if (!lib) {
52-
int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
53-
for (int i = 0; i < n; ++i) {
54-
void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
55-
if (handle) {
56-
lib = handle;
57-
// printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
58-
}
86+
(void)hipDriverGetVersion(&hipVersion);
87+
char msgBuff[TRITON_HIP_MSG_BUFF_SIZE] = {0};
88+
89+
const int hipMajVersion = TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(hipVersion);
90+
if (hipMajVersion < TRITON_HIP_DRIVER_REQ_MAJOR_VERSION) {
91+
const int hipMinVersion =
92+
TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(hipVersion);
93+
const int hipPatchVersion =
94+
TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(hipVersion);
95+
snprintf(msgBuff, sizeof(msgBuff),
96+
"libamdhip64 version %d.%d.%d is not supported! Required major "
97+
"version is >=%d.",
98+
hipMajVersion, hipMinVersion, hipPatchVersion,
99+
TRITON_HIP_DRIVER_REQ_MAJOR_VERSION);
100+
PyErr_SetString(PyExc_RuntimeError, msgBuff);
101+
dlclose(lib);
102+
return -1;
103+
}
104+
105+
TRITON_HIP_DRIVER_LOG_VERSION(hipVersion, msgBuff);
106+
107+
return hipVersion;
108+
}
109+
110+
bool initSymbolTable() {
111+
void *lib;
112+
113+
// Go through the list of search paths to dlopen the first HIP driver library.
114+
int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
115+
for (int i = 0; i < n; ++i) {
116+
void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
117+
if (handle) {
118+
lib = handle;
119+
// printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
59120
}
60121
}
122+
61123
if (!lib) {
62124
PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
63125
return false;
64126
}
65127

128+
int hipVersion = checkDriverVersion(lib);
129+
if (hipVersion == -1)
130+
return false;
131+
132+
const char *error = NULL;
66133
typedef hipError_t (*hipGetProcAddress_fn)(
67134
const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
68135
hipDriverProcAddressQueryResult *symbolStatus);
69136
hipGetProcAddress_fn hipGetProcAddress;
70137
dlerror(); // Clear existing errors
71-
const char *error = NULL;
138+
72139
*(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
73140
error = dlerror();
74141
if (error) {
@@ -79,7 +146,6 @@ bool initSymbolTable() {
79146
}
80147

81148
// Resolve all symbols we are interested in.
82-
int hipVersion = HIP_VERSION;
83149
uint64_t hipFlags = 0;
84150
hipDriverProcAddressQueryResult symbolStatus;
85151
hipError_t status = hipSuccess;
@@ -106,8 +172,9 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) {
106172
{
107173
const char *prefix = "Triton Error [HIP]: ";
108174
const char *str = hipSymbolTable.hipGetErrorString(code);
109-
char err[1024] = {0};
110-
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str);
175+
char err[TRITON_HIP_MSG_BUFF_SIZE] = {0};
176+
snprintf(err, sizeof(err), "%s Code: %d, Messsage: %s", prefix, code,
177+
str);
111178
PyGILState_STATE gil_state;
112179
gil_state = PyGILState_Ensure();
113180
PyErr_SetString(PyExc_RuntimeError, err);

0 commit comments

Comments
 (0)