Commit 8c21680

[ROCm] prefer hip interfaces over roc during hipify (#22394)
### Description

Change the hipify step to remove the `-roc` option to hipify-perl. This prefers hipblas over rocblas. rocblas can still be called directly, as is done in TunableOp.

### Motivation and Context

hip interfaces are preferred over roc interfaces when porting from cuda to hip; calling roc interfaces directly is meant for ROCm-specific enhancements or extensions.
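For illustration, a minimal sketch of the mapping difference, assuming the standard hipify-perl translations (the GEMM call is an invented example, not taken from this diff):

```cpp
// Original CUDA call:
//   cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
//               &alpha, A, lda, B, ldb, &beta, C, ldc);

// hipify-perl without -roc (the behavior after this change) targets the
// hipBLAS marshalling layer, which mirrors the cuBLAS API:
//   hipblasSgemm(handle, HIPBLAS_OP_N, HIPBLAS_OP_N, m, n, k,
//                &alpha, A, lda, B, ldb, &beta, C, ldc);

// hipify-perl -roc (the previous behavior) targets rocBLAS directly:
//   rocblas_sgemm(handle, rocblas_operation_none, rocblas_operation_none,
//                 m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc);
```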
1 parent ec7aa63 commit 8c21680

42 files changed: +689 -242 lines

cmake/onnxruntime_kernel_explorer.cmake

Lines changed: 1 addition & 1 deletion
```diff
@@ -64,7 +64,7 @@ elseif (onnxruntime_USE_ROCM)
   )
   auto_set_source_files_hip_language(${kernel_explorer_kernel_srcs} ${kernel_explorer_rocm_kernel_srcs})
   target_sources(kernel_explorer PRIVATE ${kernel_explorer_rocm_kernel_srcs})
-  target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1)
+  target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1 HIPBLAS_V2)
   if (onnxruntime_USE_COMPOSABLE_KERNEL)
     target_compile_definitions(kernel_explorer PRIVATE USE_COMPOSABLE_KERNEL)
     if (onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE)
```
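A note on the new `HIPBLAS_V2` define (an interpretation, not stated in the commit message): it opts into the hipBLAS v2 API, where the generic GEMM entry points describe matrices with `hipDataType` and accumulators with `hipblasComputeType_t` instead of the legacy `hipblasDatatype_t`; the header changes below depend on those types. A minimal sketch of the assumed effect:

```cpp
// Assumption: defining HIPBLAS_V2 before including the header (equivalent to
// the -DHIPBLAS_V2 compile definition added above) selects the v2 interfaces.
#define HIPBLAS_V2
#include <hipblas/hipblas.h>

// Types the updated headers below rely on:
hipDataType a_type = HIP_R_16F;                      // matrix element type
hipblasComputeType_t compute = HIPBLAS_COMPUTE_16F;  // accumulator type
```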

cmake/onnxruntime_providers_rocm.cmake

Lines changed: 3 additions & 2 deletions
```diff
@@ -8,7 +8,7 @@

   find_package(HIP)
   find_package(hiprand REQUIRED)
-  find_package(rocblas REQUIRED)
+  find_package(hipblas REQUIRED)
   find_package(MIOpen REQUIRED)
   find_package(hipfft REQUIRED)

@@ -50,7 +50,7 @@
   find_library(RCCL_LIB rccl REQUIRED)
   find_library(ROCTRACER_LIB roctracer64 REQUIRED)
   find_package(rocm_smi REQUIRED)
-  set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB})
+  set(ONNXRUNTIME_ROCM_LIBS roc::hipblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB})
   include_directories(${ROCM_SMI_INCLUDE_DIR})
   link_directories(${ROCM_SMI_LIB_DIR})

@@ -155,6 +155,7 @@

   set_target_properties(onnxruntime_providers_rocm PROPERTIES LINKER_LANGUAGE CXX)
   set_target_properties(onnxruntime_providers_rocm PROPERTIES FOLDER "ONNXRuntime")
+  target_compile_definitions(onnxruntime_providers_rocm PRIVATE HIPBLAS_V2)

   if (onnxruntime_ENABLE_TRAINING)
     target_include_directories(onnxruntime_providers_rocm PRIVATE ${ORTTRAINING_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining ${MPI_CXX_INCLUDE_DIRS})
```

include/onnxruntime/core/providers/rocm/rocm_context.h

Lines changed: 5 additions & 5 deletions
```diff
@@ -7,7 +7,7 @@
 #include "core/providers/custom_op_context.h"
 #include <hip/hip_runtime.h>
 #include <miopen/miopen.h>
-#include <rocblas/rocblas.h>
+#include <hipblas/hipblas.h>

 namespace Ort {

@@ -16,7 +16,7 @@ namespace Custom {
 struct RocmContext : public CustomOpContext {
   hipStream_t hip_stream = {};
   miopenHandle_t miopen_handle = {};
-  rocblas_handle rblas_handle = {};
+  hipblasHandle_t blas_handle = {};

   void Init(const OrtKernelContext& kernel_ctx) {
     const auto& ort_api = Ort::GetApi();
@@ -40,11 +40,11 @@ struct RocmContext : public CustomOpContext {

     resource = {};
     status = ort_api.KernelContext_GetResource(
-        &kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::rocblas_handle_t, &resource);
+        &kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::hipblas_handle_t, &resource);
     if (status) {
-      ORT_CXX_API_THROW("failed to fetch rocblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+      ORT_CXX_API_THROW("failed to fetch hipblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
     }
-    rblas_handle = reinterpret_cast<rocblas_handle>(resource);
+    blas_handle = reinterpret_cast<hipblasHandle_t>(resource);
   }
 };
```
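For downstream custom-op code, a minimal usage sketch against the updated context (`PrepareBlas` is a hypothetical helper; `RocmContext::Init` and `hipblasSetStream` are real APIs):

```cpp
#include <hipblas/hipblas.h>
#include "core/providers/rocm/rocm_context.h"

// Hypothetical helper: the BLAS resource fetched from the kernel context is
// now a hipblasHandle_t, so hipBLAS calls can be issued on it directly.
void PrepareBlas(const OrtKernelContext& kernel_ctx) {
  Ort::Custom::RocmContext ctx;
  ctx.Init(kernel_ctx);  // fills hip_stream, miopen_handle, blas_handle
  hipblasSetStream(ctx.blas_handle, ctx.hip_stream);
  // ... issue hipblas* GEMMs on ctx.blas_handle ...
}
```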

include/onnxruntime/core/providers/rocm/rocm_resource.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -8,5 +8,5 @@
 enum RocmResource : int {
   hip_stream_t = rocm_resource_offset,
   miopen_handle_t,
-  rocblas_handle_t
+  hipblas_handle_t
 };
```

onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu

Lines changed: 3 additions & 3 deletions
```diff
@@ -396,7 +396,7 @@ Status LaunchLongformerSoftmaxKernel(
   cudaDataType_t Atype;
   cudaDataType_t Btype;
   cudaDataType_t Ctype;
-  cudaDataType_t resultType;
+  cublasComputeType_t resultType;
   cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;

   __half one_fp16, zero_fp16;
@@ -412,7 +412,7 @@ Status LaunchLongformerSoftmaxKernel(
     Atype = CUDA_R_16F;
     Btype = CUDA_R_16F;
     Ctype = CUDA_R_16F;
-    resultType = CUDA_R_16F;
+    resultType = CUBLAS_COMPUTE_16F;
     algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
   } else {
     one_fp32 = 1.f;
@@ -423,7 +423,7 @@ Status LaunchLongformerSoftmaxKernel(
     Atype = CUDA_R_32F;
     Btype = CUDA_R_32F;
     Ctype = CUDA_R_32F;
-    resultType = CUDA_R_32F;
+    resultType = CUBLAS_COMPUTE_32F;
   }

   // Strided batch matrix multiply
```

onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu

Lines changed: 3 additions & 3 deletions
```diff
@@ -221,7 +221,7 @@ Status LaunchLongformerSoftmaxSimpleKernel(
   cudaDataType_t Atype;
   cudaDataType_t Btype;
   cudaDataType_t Ctype;
-  cudaDataType_t resultType;
+  cublasComputeType_t resultType;
   cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;

   __half one_fp16, zero_fp16;
@@ -237,7 +237,7 @@ Status LaunchLongformerSoftmaxSimpleKernel(
     Atype = CUDA_R_16F;
     Btype = CUDA_R_16F;
     Ctype = CUDA_R_16F;
-    resultType = CUDA_R_16F;
+    resultType = CUBLAS_COMPUTE_16F;
     algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
   } else {
     one_fp32 = 1.f;
@@ -248,7 +248,7 @@ Status LaunchLongformerSoftmaxSimpleKernel(
     Atype = CUDA_R_32F;
     Btype = CUDA_R_32F;
     Ctype = CUDA_R_32F;
-    resultType = CUDA_R_32F;
+    resultType = CUBLAS_COMPUTE_32F;
   }

   // Strided batch matrix multiply
```
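Context for the two CUDA-side fixes above: `resultType` feeds the compute-type parameter of `cublasGemmStridedBatchedEx`, which in the CUDA 11+ API is a `cublasComputeType_t` rather than a `cudaDataType_t`; correcting the CUDA source presumably also lets hipify map it to `hipblasComputeType_t` under the hipBLAS v2 API. A sketch of the call shape (parameter order per the cuBLAS documentation):

```cpp
// The corrected variables feed a call of this shape:
//   cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k,
//                              alpha,
//                              A, Atype, lda, strideA,  // cudaDataType_t
//                              B, Btype, ldb, strideB,  // cudaDataType_t
//                              beta,
//                              C, Ctype, ldc, strideC,  // cudaDataType_t
//                              batchCount,
//                              resultType,              // cublasComputeType_t
//                              algo);                   // cublasGemmAlgo_t
```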

onnxruntime/contrib_ops/rocm/bert/attention.cu

Lines changed: 3 additions & 3 deletions
```diff
@@ -84,7 +84,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   Tensor* present = context->Output(kPresentOutputIndex, present_shape);

   auto stream = Stream(context);
-  rocblas_handle rocblas = GetRocblasHandle(context);
+  hipblasHandle_t hipblas = GetHipblasHandle(context);

   using HipT = typename ToHipType<T>::MappedType;
   using QkvProjectGeneric = GemmPermuteGenericPipeline<HipT>;
@@ -113,7 +113,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   auto& params = gemm_permute_params;
   params.tuning_ctx = GetTuningContext();
   params.stream = context->GetComputeStream();
-  params.handle = rocblas;
+  params.handle = hipblas;
   params.attention = &attn;
   params.device_prop = &device_prop;
@@ -179,7 +179,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   auto& params = gemm_softmax_gemm_permute_params;
   params.tuning_ctx = GetTuningContext();
   params.stream = context->GetComputeStream();
-  params.handle = rocblas;
+  params.handle = hipblas;
   params.attention = &attn;
   params.device_prop = &device_prop;
   // FIXME: the params.scale seems to be different from AttentionParameters::scale;
```

onnxruntime/contrib_ops/rocm/bert/attention_impl.cu

Lines changed: 8 additions & 8 deletions
```diff
@@ -182,7 +182,7 @@ Status DecoderQkvToContext(
     const hipDeviceProp_t& prop,
     RocmTuningContext* tuning_ctx,
     Stream* ort_stream,
-    rocblas_handle& rocblas,
+    hipblasHandle_t& hipblas,
     const size_t element_size,
     const int batch_size,
     const int sequence_length,
@@ -284,7 +284,7 @@
   const int strideB = sequence_length * head_size;
   if (use_past && static_kv) {
     ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
-        tuning_ctx, ort_stream, rocblas,
+        tuning_ctx, ort_stream, hipblas,
         blas::BlasOp::Trans, blas::BlasOp::NonTrans,
         kv_sequence_length, sequence_length, head_size,
         /*alpha=*/rsqrt_head_size,
@@ -295,7 +295,7 @@
         BN));
   } else {
     ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
-        tuning_ctx, ort_stream, rocblas,
+        tuning_ctx, ort_stream, hipblas,
         blas::BlasOp::Trans, blas::BlasOp::NonTrans,
         kv_sequence_length, sequence_length, head_size,
         /*alpha=*/rsqrt_head_size,
@@ -320,7 +320,7 @@
   // compute P*V (as V*P), and store in scratch3: BxNxSxH
   if (use_past && static_kv) {
     ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
-        tuning_ctx, ort_stream, rocblas,
+        tuning_ctx, ort_stream, hipblas,
         blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
         head_size, sequence_length, kv_sequence_length,
         /*alpha=*/1.0f,
@@ -331,7 +331,7 @@
         BN));
   } else {
     ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
-        tuning_ctx, ort_stream, rocblas,
+        tuning_ctx, ort_stream, hipblas,
         blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
         head_size, sequence_length, kv_sequence_length,
         /*alpha=*/1.0f,
@@ -351,7 +351,7 @@ Status LaunchDecoderAttentionKernel(
     const hipDeviceProp_t& prop,
     RocmTuningContext* tuning_ctx,
     Stream* stream,
-    rocblas_handle& rocblas,
+    hipblasHandle_t& hipblas,
     const size_t element_size,
     const int batch_size,
     const int sequence_length,
@@ -378,7 +378,7 @@
         prop,
         tuning_ctx,
         stream,
-        rocblas,
+        hipblas,
         element_size,
         batch_size,
         sequence_length,
@@ -405,7 +405,7 @@
         prop,
         tuning_ctx,
         stream,
-        rocblas,
+        hipblas,
         element_size,
         batch_size,
         sequence_length,
```

onnxruntime/contrib_ops/rocm/bert/attention_impl.h

Lines changed: 50 additions & 55 deletions
```diff
@@ -4,7 +4,7 @@
 #pragma once

 #include <hip/hip_fp16.h>
-#include <rocblas/rocblas.h>
+#include <hipblas/hipblas.h>
 #include "contrib_ops/cpu/bert/attention_common.h"
 #include "core/providers/rocm/shared_inc/rocm_utils.h"
 #include "core/providers/rocm/tunable/rocm_tunable.h"
@@ -70,64 +70,59 @@ Status LaunchConcatTensorToTensor(hipStream_t stream,
                                   const half* tensor_add,
                                   half* tensor_out);

-inline rocblas_status _compat_rocblas_gemm_strided_batched_ex(rocblas_handle handle,
-                                                              rocblas_operation transa,
-                                                              rocblas_operation transb,
-                                                              int m,
-                                                              int n,
-                                                              int k,
-                                                              const void* alpha,
-                                                              const void* A,
-                                                              rocblas_datatype a_type,
-                                                              rocblas_int lda,
-                                                              rocblas_stride stride_A,
-                                                              const void* b,
-                                                              rocblas_datatype b_type,
-                                                              rocblas_int ldb,
-                                                              rocblas_stride stride_b,
-                                                              const void* beta,
-                                                              void* c,
-                                                              rocblas_datatype c_type,
-                                                              rocblas_int ldc,
-                                                              rocblas_stride stride_c,
-                                                              rocblas_int batch_count,
-                                                              rocblas_datatype compute_type,
-                                                              rocblas_gemm_algo algo) {
-  return rocblas_gemm_strided_batched_ex(handle,
-                                         transa,
-                                         transb,
-                                         m,            // m
-                                         n,            // n
-                                         k,            // k
-                                         alpha,        // alpha
-                                         A,            // A
-                                         a_type,       // A type
-                                         lda,          // lda
-                                         stride_A,     // strideA
-                                         b,            // B
-                                         b_type,       // B type
-                                         ldb,          // ldb
-                                         stride_b,     // strideB
-                                         beta,         // beta
-                                         c,            // C
-                                         c_type,       // C type
-                                         ldc,          // ldc
-                                         stride_c,     // strideC
-                                         c,            // D = C
-                                         c_type,       // D type = C type
-                                         ldc,          // ldd = ldc
-                                         stride_c,     // strideD = strideC
-                                         batch_count,  // batch count
-                                         compute_type,
-                                         algo,
-                                         0, 0);
+inline hipblasStatus_t _compat_hipblas_gemm_strided_batched_ex(hipblasHandle_t handle,
+                                                               hipblasOperation_t transa,
+                                                               hipblasOperation_t transb,
+                                                               int m,
+                                                               int n,
+                                                               int k,
+                                                               const void* alpha,
+                                                               const void* A,
+                                                               hipDataType a_type,
+                                                               int lda,
+                                                               hipblasStride stride_A,
+                                                               const void* b,
+                                                               hipDataType b_type,
+                                                               int ldb,
+                                                               hipblasStride stride_b,
+                                                               const void* beta,
+                                                               void* c,
+                                                               hipDataType c_type,
+                                                               int ldc,
+                                                               hipblasStride stride_c,
+                                                               int batch_count,
+                                                               hipblasComputeType_t compute_type,
+                                                               hipblasGemmAlgo_t algo) {
+  return hipblasGemmStridedBatchedEx(handle,
+                                     transa,
+                                     transb,
+                                     m,            // m
+                                     n,            // n
+                                     k,            // k
+                                     alpha,        // alpha
+                                     A,            // A
+                                     a_type,       // A type
+                                     lda,          // lda
+                                     stride_A,     // strideA
+                                     b,            // B
+                                     b_type,       // B type
+                                     ldb,          // ldb
+                                     stride_b,     // strideB
+                                     beta,         // beta
+                                     c,            // C
+                                     c_type,       // C type
+                                     ldc,          // ldc
+                                     stride_c,     // strideC
+                                     batch_count,  // batch count
+                                     compute_type,
+                                     algo);
 }

 // Compatible for CublasMathModeSetter
-class CompatRocblasMathModeSetter {
+class CompatHipblasMathModeSetter {
  public:
-  CompatRocblasMathModeSetter(const hipDeviceProp_t&,
-                              rocblas_handle,
+  CompatHipblasMathModeSetter(const hipDeviceProp_t&,
+                              hipblasHandle_t,
                               int) {
   }
 };
```
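A hedged usage sketch of the renamed wrapper (the helper function and its fixed operation choices are invented; the enum values are real hipBLAS v2 identifiers, and namespace qualification of the wrapper is omitted):

```cpp
#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>
#include "contrib_ops/rocm/bert/attention_impl.h"  // for the compat wrapper

// Hypothetical half-precision strided-batched GEMM routed through the
// wrapper; pointers and dimensions are caller-supplied device data.
hipblasStatus_t HalfGemmStridedBatched(hipblasHandle_t handle,
                                       const __half* A, int lda, hipblasStride strideA,
                                       const __half* B, int ldb, hipblasStride strideB,
                                       __half* C, int ldc, hipblasStride strideC,
                                       int m, int n, int k, int batch_count) {
  const __half alpha = __float2half(1.0f);
  const __half beta = __float2half(0.0f);
  return _compat_hipblas_gemm_strided_batched_ex(
      handle, HIPBLAS_OP_N, HIPBLAS_OP_N, m, n, k,
      &alpha, A, HIP_R_16F, lda, strideA,
      B, HIP_R_16F, ldb, strideB,
      &beta, C, HIP_R_16F, ldc, strideC,
      batch_count, HIPBLAS_COMPUTE_16F, HIPBLAS_GEMM_DEFAULT);
}
```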

onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh

Lines changed: 1 addition & 1 deletion
```diff
@@ -32,7 +32,7 @@ struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
     return MakeString("M", m, "_N", n, "_K", k, "_B", batch);
   }

-  rocblas_handle handle;
+  hipblasHandle_t handle;
   const AttentionParameters* attention;
   const hipDeviceProp_t* device_prop;
```
