@@ -27,6 +27,34 @@ static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
27
27
FOR_EACH_ERR_FN(hipFuncGetAttribute, int *, hipFunction_attribute attr, \
28
28
hipFunction_t function)
29
29
30
// A HIP driver version is packed into a single integer as
//   HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR * 100000 +
//   HIP_VERSION_PATCH
// The three helpers below unpack the individual components again. Each one
// expands its argument exactly once, so passing an expression is safe.
#define TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version) ((version) / 10000000)
#define TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version)                       \
  (((version) / 100000) % 100)
#define TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version) ((version) % 100000)
// Lowest accepted major version: the one of the HIP headers this file is
// compiled against.
#define TRITON_HIP_DRIVER_REQ_MAJOR_VERSION (HIP_VERSION_MAJOR)
37
+
38
// Uncomment to print the detected libamdhip64 version to stdout.
// #define TRITON_HIP_DRIVER_DBG_VERSION

// TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff)
//   version: packed HIP driver version integer.
//   msgBuff: a char ARRAY (not a pointer) used as scratch space; the debug
//            variant applies sizeof to it to bound the snprintf.
// Expands to a single statement WITHOUT a trailing semicolon, so the caller
// supplies the ';' and the macro can sit in an unbraced if/else branch.
#ifdef TRITON_HIP_DRIVER_DBG_VERSION
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff)                        \
  do {                                                                         \
    snprintf(msgBuff, sizeof(msgBuff), "libamdhip64 version is: %d.%d.%d",     \
             TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version),                 \
             TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version),                 \
             TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version));                \
    printf("%s\n", msgBuff);                                                   \
  } while (0)
#else
// Logging disabled: only mark both arguments as used to silence warnings.
#define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff)                        \
  do {                                                                         \
    (void)(msgBuff);                                                           \
    (void)(version);                                                           \
  } while (0)
#endif

// Size in bytes of the stack buffers used to format diagnostic messages.
#define TRITON_HIP_MSG_BUFF_SIZE (1024U)
57
+
30
58
// The HIP symbol table for holding resolved dynamic library symbols.
31
59
struct HIPSymbolTable {
32
60
#define DEFINE_EACH_ERR_FIELD (hipSymbolName , ...) \
@@ -39,36 +67,75 @@ struct HIPSymbolTable {
39
67
40
68
// Single file-scope instance of the symbol table. Code in this file calls HIP
// entry points through its fields (e.g. hipSymbolTable.hipGetErrorString).
static struct HIPSymbolTable hipSymbolTable ;
41
69
42
- bool initSymbolTable () {
43
- // Use the HIP runtime library loaded into the existing process if it exists.
44
- void * lib = dlopen ("libamdhip64.so" , RTLD_NOLOAD );
45
- if (lib ) {
46
- // printf("[triton] chosen loaded libamdhip64.so in the process\n");
70
+ static int checkDriverVersion (void * lib ) {
71
+ int hipVersion = -1 ;
72
+ const char * error = NULL ;
73
+ typedef hipError_t (* hipDriverGetVersion_fn )(int * driverVersion );
74
+ hipDriverGetVersion_fn hipDriverGetVersion ;
75
+ dlerror (); // Clear existing errors
76
+ hipDriverGetVersion =
77
+ (hipDriverGetVersion_fn )dlsym (lib , "hipDriverGetVersion" );
78
+ error = dlerror ();
79
+ if (error ) {
80
+ PyErr_SetString (PyExc_RuntimeError ,
81
+ "cannot query 'hipDriverGetVersion' from libamdhip64.so" );
82
+ dlclose (lib );
83
+ return -1 ;
47
84
}
48
85
49
- // Otherwise, go through the list of search paths to dlopen the first HIP
50
- // driver library.
51
- if (!lib ) {
52
- int n = sizeof (hipLibSearchPaths ) / sizeof (hipLibSearchPaths [0 ]);
53
- for (int i = 0 ; i < n ; ++ i ) {
54
- void * handle = dlopen (hipLibSearchPaths [i ], RTLD_LAZY | RTLD_LOCAL );
55
- if (handle ) {
56
- lib = handle ;
57
- // printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
58
- }
86
+ (void )hipDriverGetVersion (& hipVersion );
87
+ char msgBuff [TRITON_HIP_MSG_BUFF_SIZE ] = {0 };
88
+
89
+ const int hipMajVersion = TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION (hipVersion );
90
+ if (hipMajVersion < TRITON_HIP_DRIVER_REQ_MAJOR_VERSION ) {
91
+ const int hipMinVersion =
92
+ TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION (hipVersion );
93
+ const int hipPatchVersion =
94
+ TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION (hipVersion );
95
+ snprintf (msgBuff , sizeof (msgBuff ),
96
+ "libamdhip64 version %d.%d.%d is not supported! Required major "
97
+ "version is >=%d." ,
98
+ hipMajVersion , hipMinVersion , hipPatchVersion ,
99
+ TRITON_HIP_DRIVER_REQ_MAJOR_VERSION );
100
+ PyErr_SetString (PyExc_RuntimeError , msgBuff );
101
+ dlclose (lib );
102
+ return -1 ;
103
+ }
104
+
105
+ TRITON_HIP_DRIVER_LOG_VERSION (hipVersion , msgBuff );
106
+
107
+ return hipVersion ;
108
+ }
109
+
110
+ bool initSymbolTable () {
111
+ void * lib ;
112
+
113
+ // Go through the list of search paths to dlopen the first HIP driver library.
114
+ int n = sizeof (hipLibSearchPaths ) / sizeof (hipLibSearchPaths [0 ]);
115
+ for (int i = 0 ; i < n ; ++ i ) {
116
+ void * handle = dlopen (hipLibSearchPaths [i ], RTLD_LAZY | RTLD_LOCAL );
117
+ if (handle ) {
118
+ lib = handle ;
119
+ // printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
59
120
}
60
121
}
122
+
61
123
if (!lib ) {
62
124
PyErr_SetString (PyExc_RuntimeError , "cannot open libamdhip64.so" );
63
125
return false;
64
126
}
65
127
128
+ int hipVersion = checkDriverVersion (lib );
129
+ if (hipVersion == -1 )
130
+ return false;
131
+
132
+ const char * error = NULL ;
66
133
typedef hipError_t (* hipGetProcAddress_fn )(
67
134
const char * symbol , void * * pfn , int hipVersion , uint64_t hipFlags ,
68
135
hipDriverProcAddressQueryResult * symbolStatus );
69
136
hipGetProcAddress_fn hipGetProcAddress ;
70
137
dlerror (); // Clear existing errors
71
- const char * error = NULL ;
138
+
72
139
* (void * * )& hipGetProcAddress = dlsym (lib , "hipGetProcAddress" );
73
140
error = dlerror ();
74
141
if (error ) {
@@ -79,7 +146,6 @@ bool initSymbolTable() {
79
146
}
80
147
81
148
// Resolve all symbols we are interested in.
82
- int hipVersion = HIP_VERSION ;
83
149
uint64_t hipFlags = 0 ;
84
150
hipDriverProcAddressQueryResult symbolStatus ;
85
151
hipError_t status = hipSuccess ;
@@ -106,8 +172,9 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) {
106
172
{
107
173
const char * prefix = "Triton Error [HIP]: " ;
108
174
const char * str = hipSymbolTable .hipGetErrorString (code );
109
- char err [1024 ] = {0 };
110
- snprintf (err , 1024 , "%s Code: %d, Messsage: %s" , prefix , code , str );
175
+ char err [TRITON_HIP_MSG_BUFF_SIZE ] = {0 };
176
+ snprintf (err , sizeof (err ), "%s Code: %d, Messsage: %s" , prefix , code ,
177
+ str );
111
178
PyGILState_STATE gil_state ;
112
179
gil_state = PyGILState_Ensure ();
113
180
PyErr_SetString (PyExc_RuntimeError , err );
0 commit comments