Skip to content

Commit d704051

Browse files
authored
Port existing ROCM ukernels from HIP to C. (#19194)
It's purely device code, so it doesn't need HIP's defining feature of generating both host and device code. It can be just C code that happens to be compiled to the AMDGPU target.

The flags are taken from the `users/benvanik/amdgpu` branch, `build_tools/cmake/iree_amdgpu_library.cmake`.

Tested with:

```
pytest experimental/regression_suite/tests/pregenerated/test_ukernel.py -k gfx942
```

Notes:

- Completely self-contained C avoids including even C standard library headers, so that we don't run into compilation failures on some CI host based on the C headers installed on it (see #19194 (comment)).
- The old code used to avoid `-O3` on RDNA3 due to numerical correctness issues. Based on what we know at this point about RDNA3, numerical correctness issues are to be expected on RDNA3, and we should not refrain from `-O3`; instead, we should relax test tolerances on RDNA3 as needed.
- Type changes:
  - Input buffers used to be passed as (non-`const`) `T*`; they are now passed as `const T*`.
  - Size parameters were passed as `size_t`. I thought: let's *not* have a size type in AMDGPU ukernels, where we are dealing with multiple address spaces with different pointer widths - let's be explicit. Then the direct mapping of `size_t` would be `uint64_t` (given this is the global address space), but see the next point:
  - Switched from unsigned to signed sizes. Generally, unsigned sizes (as in `size_t`) are a legacy choice that we don't have to be bound to here, given the self-containedness of this ukernel code. And in the context of lowering MLIR to ukernel calls, the typical MLIR values that would lower to these sizes are of MLIR `index` type, and while that is signless, it is more often treated as signed than unsigned.

---------

Signed-off-by: Benoit Jacob <[email protected]>
1 parent 6583762 commit d704051

File tree

4 files changed

+206
-103
lines changed

4 files changed

+206
-103
lines changed

compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,13 @@ if(NOT IREE_TARGET_BACKEND_ROCM)
77
return()
88
endif()
99

10-
# Check if HIP is installed on system.
11-
# HIP is required to compile ukernels.
12-
# TODO: We can do better than this and ensure that headers are always available.
13-
if(NOT IREE_ROCM_PATH)
14-
set(IREE_ROCM_PATH "/opt/rocm")
15-
endif()
16-
set(IREE_ROCM_VERSION "${IREE_ROCM_PATH}/include/hip/hip_version.h")
17-
if(NOT EXISTS ${IREE_ROCM_VERSION})
18-
message(STATUS
19-
"hip runtime cannot be found in ${IREE_ROCM_PATH}.
20-
Please try setting IREE_ROCM_PATH to rocm directory.
21-
Ukernels will not be compiled.")
22-
return()
23-
endif()
24-
25-
2610
iree_add_all_subdirs()
2711

2812
set(_platform_lib_reldir "iree_platform_libs/rocm")
2913
set(_device_bc_path "${IREE_COMPILER_DYLIB_DIR}/iree_platform_libs/rocm")
3014
set(_amd_ukernel_libs)
3115
set(_amd_ukernel_targets)
32-
function(iree_rocm_bitcode_library)
16+
function(iree_amdgpu_bitcode_library)
3317
cmake_parse_arguments(
3418
_RULE
3519
""
@@ -45,30 +29,29 @@ function(iree_rocm_bitcode_library)
4529
endif()
4630

4731
set(_ROCM_ARCH "${_RULE_ROCM_ARCH}")
48-
set(OPT_FLAG "-O0")
49-
if(_ROCM_ARCH MATCHES "GFX9")
50-
set(OPT_FLAG "-O3")
51-
endif()
5232
set(_COPTS
53-
"-x" "hip"
33+
# Language: C23
34+
"-x" "c"
35+
"-std=c23"
5436

55-
# Compile only the device code for the target architecture.
56-
"--offload-device-only"
57-
"--offload-arch=${_ROCM_ARCH}"
37+
# Local headers.
38+
"-I${IREE_SOURCE_DIR}"
5839

59-
# Suppress warnings about about ROCM version (we mostly don't care).
60-
"-D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH"
40+
# Avoid dependencies.
41+
"-nogpulib"
6142

62-
# Use the ROCM specified by the IREE cmake variable (instead of guessing
63-
# or failing if ROCM is not on the user's path).
64-
"--rocm-path=${IREE_ROCM_PATH}"
43+
# Avoid ABI issues.
44+
"-fno-short-wchar" # Shouldn't matter to us, but doesn't hurt.
6545

66-
# Avoid linking in default libraries as we will link them at a later phase.
67-
"-nogpulib"
46+
# Target architecture/machine.
47+
"-target" "amdgcn-amd-amdhsa"
48+
"-march=${_ROCM_ARCH}"
49+
"-fgpu-rdc" # NOTE: may not be required for all targets.
6850

69-
# Only enable necessary optimizations S.T we can use -O3.
70-
"-Xclang" "-disable-llvm-optzns"
71-
"${OPT_FLAG}"
51+
# Optimized.
52+
"-O3"
53+
"-fno-ident"
54+
"-fvisibility=hidden"
7255

7356
# Object file only in bitcode format:
7457
"-c"
@@ -77,7 +60,8 @@ function(iree_rocm_bitcode_library)
7760

7861
set(_BITCODE_FILES)
7962
foreach(_SRC ${_RULE_SRCS})
80-
get_filename_component(_BITCODE_SRC_PATH "${_SRC}" REALPATH)
63+
get_filename_component(_SRC_PATH "${_SRC}" REALPATH)
64+
get_filename_component(_COMMON_H_PATH "common.h" REALPATH)
8165
set(_BITCODE_FILE "${_RULE_NAME}_${_SRC}_${_ROCM_ARCH}.bc")
8266
list(APPEND _BITCODE_FILES ${_BITCODE_FILE})
8367
add_custom_command(
@@ -86,12 +70,13 @@ function(iree_rocm_bitcode_library)
8670
COMMAND
8771
"${IREE_CLANG_BINARY}"
8872
${_COPTS}
89-
"${_BITCODE_SRC_PATH}"
73+
"${_SRC_PATH}"
9074
"-o"
9175
"${_BITCODE_FILE}"
9276
DEPENDS
9377
"${IREE_CLANG_BINARY}"
94-
"${_SRC}"
78+
"${_SRC_PATH}"
79+
"${_COMMON_H_PATH}"
9580
COMMENT
9681
"Compiling ${_SRC} to ${_BITCODE_FILE}"
9782
VERBATIM
@@ -127,7 +112,7 @@ endfunction()
127112
# except compile-time cost, so just picked out the popular ones.
128113
set(_ukernel_supported_chips "gfx90a" "gfx942" "gfx1030" "gfx1100")
129114
foreach(_amd_chip ${_ukernel_supported_chips})
130-
iree_rocm_bitcode_library(
115+
iree_amdgpu_bitcode_library(
131116
NAME
132117
rocm_argmax_ukernel
133118
ROCM_ARCH

compiler/plugins/target/ROCM/builtins/ukernel/argmax_ukernel.c

Lines changed: 58 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,7 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7-
#include <float.h>
8-
#include <hip/hip_fp16.h>
9-
#include <hip/hip_runtime.h>
10-
11-
extern "C" __device__ __attribute__((const)) half __ockl_wfred_max_f16(half);
12-
extern "C" __device__
13-
__attribute__((const)) int64_t __ockl_wfred_min_i64(int64_t);
14-
extern "C" __device__
15-
__attribute__((const)) int32_t __ockl_wfred_min_i32(int32_t);
7+
#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
168

179
/*
1810
Constraint/Tiling note:
@@ -21,27 +13,27 @@ only use single subgroup/warp per workgroup. This constraint is also set during
2113
tiling phase in KernelConfig.
2214
*/
2315

24-
extern "C" __device__ void __iree_uk_rocm_argmax_F32I32(float *inputBuffer,
25-
size_t input_offset,
26-
int32_t *outputBuffer,
27-
size_t output_offset,
28-
size_t reductionSize) {
29-
uint laneID = __builtin_amdgcn_workitem_id_x();
16+
void __iree_uk_rocm_argmax_F32I32(const float *inputBuffer,
17+
int64_t input_offset, int32_t *outputBuffer,
18+
int64_t output_offset,
19+
int64_t reductionSize) {
20+
const int warpSize = __builtin_amdgcn_wavefrontsize();
21+
int32_t laneID = __builtin_amdgcn_workitem_id_x();
3022
// Set identity value to handle problem non divisible by subgroupSize.
3123
float laneMax =
3224
laneID >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + laneID];
3325
int32_t laneResult = laneID;
3426

3527
// NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical
3628
// inaccuracy.
37-
uint numBatches = (reductionSize + warpSize - 1) / warpSize;
29+
int32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
3830
for (int i = 1; i < numBatches; ++i) {
39-
uint idx = warpSize * i + laneID;
31+
int32_t idx = warpSize * i + laneID;
4032
float newIn =
4133
idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx];
4234
if (newIn == laneMax)
4335
continue;
44-
laneMax = __ocml_fmax_f32(newIn, laneMax);
36+
laneMax = __builtin_fmaxf(newIn, laneMax);
4537
laneResult = newIn == laneMax ? idx : laneResult;
4638
}
4739

@@ -50,12 +42,12 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F32I32(float *inputBuffer,
5042
// https://github.com/iree-org/iree/issues/16112.
5143
float wgMax = laneMax;
5244
for (int i = 1; i < warpSize; i *= 2) {
53-
wgMax = __ocml_fmax_f32(__shfl_xor(wgMax, i), wgMax);
45+
wgMax = __builtin_fmaxf(__shfl_xor_f(wgMax, i), wgMax);
5446
}
5547
// Check if there are multiple max value holders.
5648
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
5749
// if there is only one max value holder, write and exit.
58-
if (__popcll(laneHasMaxValmask) == 1) {
50+
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
5951
if (wgMax == laneMax)
6052
outputBuffer[output_offset] = laneResult;
6153
return;
@@ -68,27 +60,27 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F32I32(float *inputBuffer,
6860
outputBuffer[output_offset] = laneResult;
6961
}
7062

71-
extern "C" __device__ void __iree_uk_rocm_argmax_F32I64(float *inputBuffer,
72-
size_t input_offset,
73-
int64_t *outputBuffer,
74-
size_t output_offset,
75-
size_t reductionSize) {
76-
uint laneID = __builtin_amdgcn_workitem_id_x();
63+
void __iree_uk_rocm_argmax_F32I64(const float *inputBuffer,
64+
int64_t input_offset, int64_t *outputBuffer,
65+
int64_t output_offset,
66+
int64_t reductionSize) {
67+
const int warpSize = __builtin_amdgcn_wavefrontsize();
68+
int32_t laneID = __builtin_amdgcn_workitem_id_x();
7769
// Set identity value to handle problem non divisible by subgroupSize.
7870
float laneMax =
7971
laneID >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + laneID];
8072
int64_t laneResult = laneID;
8173

8274
// NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical
8375
// inaccuracy.
84-
uint numBatches = (reductionSize + warpSize - 1) / warpSize;
76+
int32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
8577
for (int i = 1; i < numBatches; ++i) {
86-
uint idx = warpSize * i + laneID;
78+
int32_t idx = warpSize * i + laneID;
8779
float newIn =
8880
idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx];
8981
if (newIn == laneMax)
9082
continue;
91-
laneMax = __ocml_fmax_f32(newIn, laneMax);
83+
laneMax = __builtin_fmaxf(newIn, laneMax);
9284
laneResult = newIn == laneMax ? idx : laneResult;
9385
}
9486

@@ -97,57 +89,58 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F32I64(float *inputBuffer,
9789
// https://github.com/iree-org/iree/issues/16112.
9890
float wgMax = laneMax;
9991
for (int i = 1; i < warpSize; i *= 2) {
100-
wgMax = __ocml_fmax_f32(__shfl_xor(wgMax, i), wgMax);
92+
wgMax = __builtin_fmaxf(__shfl_xor_f(wgMax, i), wgMax);
10193
}
10294
// Check if there are multiple max value holders.
10395
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
10496
// if there is only one max value holder, write and exit.
105-
if (__popcll(laneHasMaxValmask) == 1) {
97+
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
10698
if (wgMax == laneMax)
10799
outputBuffer[output_offset] = laneResult;
108100
return;
109101
}
110102
// if there are multiple max value holders, find smallest index (argmax
111103
// semantics).
112-
int64_t indexVal = wgMax == laneMax ? laneResult : __INT64_MAX__;
104+
int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
113105
laneResult = __ockl_wfred_min_i64(indexVal);
114106
if (laneID == 0)
115107
outputBuffer[output_offset] = laneResult;
116108
}
117109

118-
extern "C" __device__ void __iree_uk_rocm_argmax_F16I32(half *inputBuffer,
119-
size_t input_offset,
120-
int32_t *outputBuffer,
121-
size_t output_offset,
122-
size_t reductionSize) {
123-
half NEG_F16_MAX = __float2half(-65504.0f);
124-
uint laneID = __builtin_amdgcn_workitem_id_x();
110+
void __iree_uk_rocm_argmax_F16I32(const _Float16 *inputBuffer,
111+
int64_t input_offset, int32_t *outputBuffer,
112+
int64_t output_offset,
113+
int64_t reductionSize) {
114+
const int warpSize = __builtin_amdgcn_wavefrontsize();
115+
_Float16 NEG_F16_MAX = (_Float16)(-65504.0f);
116+
int32_t laneID = __builtin_amdgcn_workitem_id_x();
125117
// Set identity value to handle problem non divisible by subgroupSize.
126-
half laneMax = laneID >= reductionSize ? NEG_F16_MAX
127-
: inputBuffer[input_offset + laneID];
118+
_Float16 laneMax = laneID >= reductionSize
119+
? NEG_F16_MAX
120+
: inputBuffer[input_offset + laneID];
128121
int32_t laneResult = laneID;
129122

130-
uint numBatches = (reductionSize + warpSize - 1) / warpSize;
123+
int32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
131124
for (int i = 1; i < numBatches; ++i) {
132-
uint idx = warpSize * i + laneID;
133-
half newIn =
125+
int32_t idx = warpSize * i + laneID;
126+
_Float16 newIn =
134127
idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx];
135128
if (newIn == laneMax)
136129
continue;
137-
laneMax = __ocml_fmax_f16(newIn, laneMax);
130+
laneMax = __builtin_fmaxf16(newIn, laneMax);
138131
laneResult = newIn == laneMax ? idx : laneResult;
139132
}
140-
141133
// Final reduction with one subgroup
142-
half wgMax = __ockl_wfred_max_f16(laneMax);
134+
_Float16 wgMax = __ockl_wfred_max_f16(laneMax);
143135
// Check if there are multiple max value holders.
144136
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
145137
// if there is only one max value holder, write and exit.
146-
if (__popcll(laneHasMaxValmask) == 1) {
138+
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
147139
if (wgMax == laneMax)
148140
outputBuffer[output_offset] = laneResult;
149141
return;
150142
}
143+
151144
// if there are multiple max value holders, find smallest index (argmax
152145
// semantics).
153146
int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__;
@@ -156,42 +149,43 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F16I32(half *inputBuffer,
156149
outputBuffer[output_offset] = laneResult;
157150
}
158151

159-
extern "C" __device__ void __iree_uk_rocm_argmax_F16I64(half *inputBuffer,
160-
size_t input_offset,
161-
int64_t *outputBuffer,
162-
size_t output_offset,
163-
size_t reductionSize) {
164-
half NEG_F16_MAX = __float2half(-65504.0f);
165-
uint laneID = __builtin_amdgcn_workitem_id_x();
152+
void __iree_uk_rocm_argmax_F16I64(const _Float16 *inputBuffer,
153+
int64_t input_offset, int64_t *outputBuffer,
154+
int64_t output_offset,
155+
int64_t reductionSize) {
156+
const int warpSize = __builtin_amdgcn_wavefrontsize();
157+
_Float16 NEG_F16_MAX = (_Float16)(-65504.0f);
158+
int32_t laneID = __builtin_amdgcn_workitem_id_x();
166159
// Set identity value to handle problem non divisible by subgroupSize.
167-
half laneMax = laneID >= reductionSize ? NEG_F16_MAX
168-
: inputBuffer[input_offset + laneID];
160+
_Float16 laneMax = laneID >= reductionSize
161+
? NEG_F16_MAX
162+
: inputBuffer[input_offset + laneID];
169163
int64_t laneResult = laneID;
170164

171-
uint numBatches = (reductionSize + warpSize - 1) / warpSize;
165+
int32_t numBatches = (reductionSize + warpSize - 1) / warpSize;
172166
for (int i = 1; i < numBatches; ++i) {
173-
uint idx = warpSize * i + laneID;
174-
half newIn =
167+
int32_t idx = warpSize * i + laneID;
168+
_Float16 newIn =
175169
idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx];
176170
if (newIn == laneMax)
177171
continue;
178-
laneMax = __ocml_fmax_f16(newIn, laneMax);
172+
laneMax = __builtin_fmaxf16(newIn, laneMax);
179173
laneResult = newIn == laneMax ? idx : laneResult;
180174
}
181175

182176
// Final reduction with one subgroup
183-
half wgMax = __ockl_wfred_max_f16(laneMax);
177+
_Float16 wgMax = __ockl_wfred_max_f16(laneMax);
184178
// Check if there are multiple max value holders.
185179
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
186180
// if there is only one max value holder, write and exit.
187-
if (__popcll(laneHasMaxValmask) == 1) {
181+
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
188182
if (wgMax == laneMax)
189183
outputBuffer[output_offset] = laneResult;
190184
return;
191185
}
192186
// if there are multiple max value holders, find smallest index (argmax
193187
// semantics).
194-
int64_t indexVal = wgMax == laneMax ? laneResult : __INT64_MAX__;
188+
int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
195189
laneResult = __ockl_wfred_min_i64(indexVal);
196190
if (laneID == 0)
197191
outputBuffer[output_offset] = laneResult;

0 commit comments

Comments
 (0)