Skip to content

Commit 9890a51

Browse files
authored
Fix GPU handle pool singleton aliasing (#525)
Addresses handle pool singleton sharing issue between different GPU operation types in ROCm/HIP backend.
1 parent 2c4c706 commit 9890a51

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

jaxlib/gpu/vendor.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,10 +431,12 @@ typedef hipDoubleComplex gpuDoubleComplex;
431431
typedef hipComplex gpublasComplex;
432432
typedef hipDoubleComplex gpublasDoubleComplex;
433433

434-
typedef hipsolverHandle_t gpusolverDnHandle_t;
434+
// Create unique opaque pointer types for proper singleton separation - BLAS and SOLVER only
435+
typedef struct hipblasHandle_* gpublasHandle_t;
436+
typedef struct hipsolverHandle_* gpusolverDnHandle_t;
437+
435438
typedef hipblasFillMode_t gpublasFillMode_t;
436439
typedef hipsolverFillMode_t gpusolverFillMode_t;
437-
typedef hipblasHandle_t gpublasHandle_t;
438440
typedef hipblasOperation_t gpublasOperation_t;
439441
typedef hipblasStatus_t gpublasStatus_t;
440442
typedef hipCtx_t gpuContext_t;
@@ -480,8 +482,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
480482
#define GPU_C_64F HIP_C_64F
481483
#define GPU_R_64F HIP_R_64F
482484

483-
#define gpublasCreate hipblasCreate
485+
// Wrapper functions for BLAS handles to ensure unique types
486+
#define gpublasCreate(handle) hipblasCreate(reinterpret_cast<hipblasHandle_t*>(handle))
484487
#define gpublasSetStream hipblasSetStream
488+
485489
#define gpublasSgeqrfBatched hipblasSgeqrfBatched
486490
#define gpublasDgeqrfBatched hipblasDgeqrfBatched
487491
#define gpublasCgeqrfBatched hipblasCgeqrfBatched
@@ -531,8 +535,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
531535
#define GPUDNN_LSTM miopenLSTM
532536
#define GPUDNN_BIDIRECTIONAL miopenRNNbidirection
533537

534-
#define gpusolverDnCreate hipsolverCreate
538+
// Wrapper functions for SOLVER handles to ensure unique types
539+
#define gpusolverDnCreate(handle) hipsolverCreate(reinterpret_cast<hipsolverHandle_t*>(handle))
535540
#define gpusolverDnSetStream hipsolverSetStream
541+
536542
#define gpusolverDnCreateSyevjInfo hipsolverCreateSyevjInfo
537543
#define gpusolverDnDestroySyevjInfo hipsolverDestroySyevjInfo
538544
#define gpusolverDnSgeqrf hipsolverSgeqrf

0 commit comments

Comments
 (0)