Skip to content

Commit fb17bb2

Browse files
authored
[AMD] Use hipGetProcAddress to query HIP symbols (#7350)
This commit switches the HIP backend to use hipGetProcAddress to query HIP symbols. This addresses the issue we had with hipGetDeviceProperties, where the old API symbol could end up paired with the new hipDeviceProp_t struct definition. Note that this prepares for the upcoming ROCm 7 release. It also effectively drops support for ROCm 5.
1 parent 3b41514 commit fb17bb2

File tree

2 files changed

+56
-39
lines changed

2 files changed

+56
-39
lines changed

third_party/amd/backend/driver.c

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
#define __HIP_PLATFORM_AMD__
2-
// clang-format off
3-
// hip_deprecated.h needs definitions from hip_runtime.h.
42
#include <hip/hip_runtime.h>
5-
#include <hip/hip_deprecated.h>
6-
// clang-format on
3+
#include <hip/hip_runtime_api.h>
74
#define PY_SSIZE_T_CLEAN
85
#include <Python.h>
96
#include <dlfcn.h>
@@ -14,28 +11,9 @@
1411
// code should substitute the search path placeholder.
1512
static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
1613

17-
// The list of HIP dynamic library symbols and their signature we are interested
18-
// in this file.
19-
// |FOR_EACH_ERR_FN| is a macro to process APIs that return hipError_t;
20-
// |FOR_EACH_STR_FN| is a macro to process APIs that return const char *.
21-
//
22-
// HIP 6.0 introduced an updated hipGetDeviceProperties API under a new symbol,
23-
// hipGetDevicePropertiesR0600. However, the associated hipDeviceProp_t was
24-
// directly updated with breaking changes to match hipGetDevicePropertiesR0600
25-
// in the header file. We include the header file from HIP 6.0. So here if we
26-
// use hipGetDeviceProperties together with hipDeviceProp_t we will use the
27-
// old API with a new struct definition and mess up the interpretation.
28-
//
29-
// This is a known issue: https://github.com/ROCm/ROCm/issues/2728.
30-
//
31-
// For now explicitly defer to the old hipDeviceProp_t struct. This should work
32-
// for both 5.x and 6.x. In the long term we need to switch to use
33-
// hipGetProcAddress once available:
34-
// https://github.com/ROCm/clr/commit/0479cdb3dd30ef58718cad44e424bd793c394cc0
3514
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \
3615
FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \
37-
FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_tR0000 *prop, \
38-
int deviceId) \
16+
FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_t *prop, int deviceId) \
3917
FOR_EACH_ERR_FN(hipModuleLoadDataEx, hipModule_t *module, const void *image, \
4018
unsigned int numOptions, hipJitOption *options, \
4119
void **optionValues) \
@@ -80,15 +58,34 @@ bool initSymbolTable() {
8058
return false;
8159
}
8260

83-
// Resolve all symbols we are interested in.
61+
typedef hipError_t (*hipGetProcAddress_fn)(
62+
const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
63+
hipDriverProcAddressQueryResult *symbolStatus);
64+
hipGetProcAddress_fn hipGetProcAddress;
8465
dlerror(); // Clear existing errors
8566
const char *error = NULL;
67+
*(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
68+
error = dlerror();
69+
if (error) {
70+
PyErr_SetString(PyExc_RuntimeError,
71+
"cannot query 'hipGetProcAddress' from libamdhip64.so");
72+
dlclose(lib);
73+
return false;
74+
}
75+
76+
// Resolve all symbols we are interested in.
77+
int hipVersion = HIP_VERSION;
78+
uint64_t hipFlags = 0;
79+
hipDriverProcAddressQueryResult symbolStatus;
80+
hipError_t status = hipSuccess;
8681
#define QUERY_EACH_FN(hipSymbolName, ...) \
87-
*(void **)&hipSymbolTable.hipSymbolName = dlsym(lib, #hipSymbolName); \
88-
error = dlerror(); \
89-
if (error) { \
82+
status = hipGetProcAddress(#hipSymbolName, \
83+
(void **)&hipSymbolTable.hipSymbolName, \
84+
hipVersion, hipFlags, &symbolStatus); \
85+
if (status != hipSuccess) { \
9086
PyErr_SetString(PyExc_RuntimeError, \
91-
"cannot query " #hipSymbolName " from libamdhip64.so"); \
87+
"cannot get address for '" #hipSymbolName \
88+
"' from libamdhip64.so"); \
9289
dlclose(lib); \
9390
return false; \
9491
}
@@ -127,7 +124,7 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
127124
if (!PyArg_ParseTuple(args, "i", &device_id))
128125
return NULL;
129126

130-
hipDeviceProp_tR0000 props;
127+
hipDeviceProp_t props;
131128
HIP_CHECK(hipSymbolTable.hipGetDeviceProperties(&props, device_id));
132129

133130
// create a struct to hold device properties

third_party/amd/backend/driver.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def format_of(ty):
265265
src = f"""
266266
#define __HIP_PLATFORM_AMD__
267267
#include <hip/hip_runtime.h>
268+
#include <hip/hip_runtime_api.h>
268269
#include <Python.h>
269270
#include <dlfcn.h>
270271
#include <stdbool.h>
@@ -329,17 +330,36 @@ def format_of(ty):
329330
return false;
330331
}}
331332
332-
// Resolve all symbols we are interested in.
333+
typedef hipError_t (*hipGetProcAddress_fn)(
334+
const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
335+
hipDriverProcAddressQueryResult *symbolStatus);
336+
hipGetProcAddress_fn hipGetProcAddress;
333337
dlerror(); // Clear existing errors
334338
const char *error = NULL;
335-
#define QUERY_EACH_FN(hipSymbolName, ...) \\
336-
*(void **)&hipSymbolTable.hipSymbolName = dlsym(lib, #hipSymbolName); \\
337-
error = dlerror(); \\
338-
if (error) {{ \\
339-
PyErr_SetString(PyExc_RuntimeError, \\
340-
"cannot query " #hipSymbolName " from libamdhip64.so"); \\
341-
dlclose(lib); \\
342-
return false; \\
339+
*(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
340+
error = dlerror();
341+
if (error) {{
342+
PyErr_SetString(PyExc_RuntimeError,
343+
"cannot query 'hipGetProcAddress' from libamdhip64.so");
344+
dlclose(lib);
345+
return false;
346+
}}
347+
348+
// Resolve all symbols we are interested in.
349+
int hipVersion = HIP_VERSION;
350+
uint64_t hipFlags = 0;
351+
hipDriverProcAddressQueryResult symbolStatus;
352+
hipError_t status = hipSuccess;
353+
#define QUERY_EACH_FN(hipSymbolName, ...) \
354+
status = hipGetProcAddress(#hipSymbolName, \
355+
(void **)&hipSymbolTable.hipSymbolName, \
356+
hipVersion, hipFlags, &symbolStatus); \
357+
if (status != hipSuccess) {{ \
358+
PyErr_SetString(PyExc_RuntimeError, \
359+
"cannot get address for '" #hipSymbolName \
360+
"' from libamdhip64.so"); \
361+
dlclose(lib); \
362+
return false; \
343363
}}
344364
345365
HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)

0 commit comments

Comments
 (0)