44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
7- #include <float.h>
8- #include <hip/hip_fp16.h>
9- #include <hip/hip_runtime.h>
10-
11- extern "C" __device__ __attribute__((const )) half __ockl_wfred_max_f16 (half );
12- extern "C" __device__
13- __attribute__((const )) int64_t __ockl_wfred_min_i64 (int64_t );
14- extern "C" __device__
15- __attribute__((const )) int32_t __ockl_wfred_min_i32 (int32_t );
7+ #include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
168
179/*
1810Constraint/Tiling note:
@@ -21,27 +13,27 @@ only use single subgroup/warp per workgroup. This constraint is also set during
2113tiling phase in KernelConfig.
2214*/
2315
24- extern "C" __device__ void __iree_uk_rocm_argmax_F32I32 (float * inputBuffer ,
25- size_t input_offset ,
26- int32_t * outputBuffer ,
27- size_t output_offset ,
28- size_t reductionSize ) {
29- uint laneID = __builtin_amdgcn_workitem_id_x ();
16+ void __iree_uk_rocm_argmax_F32I32 (const float * inputBuffer ,
17+ int64_t input_offset , int32_t * outputBuffer ,
18+ int64_t output_offset ,
19+ int64_t reductionSize ) {
20+ const int warpSize = __builtin_amdgcn_wavefrontsize ();
21+ int32_t laneID = __builtin_amdgcn_workitem_id_x ();
3022 // Set identity value to handle problem non divisible by subgroupSize.
3123 float laneMax =
3224 laneID >= reductionSize ? - FLT_MAX : inputBuffer [input_offset + laneID ];
3325 int32_t laneResult = laneID ;
3426
3527 // NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical
3628 // inaccuracy.
37- uint numBatches = (reductionSize + warpSize - 1 ) / warpSize ;
29+ int32_t numBatches = (reductionSize + warpSize - 1 ) / warpSize ;
3830 for (int i = 1 ; i < numBatches ; ++ i ) {
39- uint idx = warpSize * i + laneID ;
31+ int32_t idx = warpSize * i + laneID ;
4032 float newIn =
4133 idx >= reductionSize ? - FLT_MAX : inputBuffer [input_offset + idx ];
4234 if (newIn == laneMax )
4335 continue ;
44- laneMax = __ocml_fmax_f32 (newIn , laneMax );
36+ laneMax = __builtin_fmaxf (newIn , laneMax );
4537 laneResult = newIn == laneMax ? idx : laneResult ;
4638 }
4739
@@ -50,12 +42,12 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F32I32(float *inputBuffer,
5042 // https://github.com/iree-org/iree/issues/16112.
5143 float wgMax = laneMax ;
5244 for (int i = 1 ; i < warpSize ; i *= 2 ) {
53- wgMax = __ocml_fmax_f32 ( __shfl_xor (wgMax , i ), wgMax );
45+ wgMax = __builtin_fmaxf ( __shfl_xor_f (wgMax , i ), wgMax );
5446 }
5547 // Check if there are multiple max value holders.
5648 uint64_t laneHasMaxValmask = __ballot (wgMax == laneMax );
5749 // if there is only one max value holder, write and exit.
58- if (__popcll (laneHasMaxValmask ) == 1 ) {
50+ if (__builtin_popcountll (laneHasMaxValmask ) == 1 ) {
5951 if (wgMax == laneMax )
6052 outputBuffer [output_offset ] = laneResult ;
6153 return ;
@@ -68,27 +60,27 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F32I32(float *inputBuffer,
6860 outputBuffer [output_offset ] = laneResult ;
6961}
7062
71- extern "C" __device__ void __iree_uk_rocm_argmax_F32I64 (float * inputBuffer ,
72- size_t input_offset ,
73- int64_t * outputBuffer ,
74- size_t output_offset ,
75- size_t reductionSize ) {
76- uint laneID = __builtin_amdgcn_workitem_id_x ();
63+ void __iree_uk_rocm_argmax_F32I64 (const float * inputBuffer ,
64+ int64_t input_offset , int64_t * outputBuffer ,
65+ int64_t output_offset ,
66+ int64_t reductionSize ) {
67+ const int warpSize = __builtin_amdgcn_wavefrontsize ();
68+ int32_t laneID = __builtin_amdgcn_workitem_id_x ();
7769 // Set identity value to handle problem non divisible by subgroupSize.
7870 float laneMax =
7971 laneID >= reductionSize ? - FLT_MAX : inputBuffer [input_offset + laneID ];
8072 int64_t laneResult = laneID ;
8173
8274 // NOTE: On F32 kernels with clang, reductionSize/blockDim.x has numerical
8375 // inaccuracy.
84- uint numBatches = (reductionSize + warpSize - 1 ) / warpSize ;
76+ int32_t numBatches = (reductionSize + warpSize - 1 ) / warpSize ;
8577 for (int i = 1 ; i < numBatches ; ++ i ) {
86- uint idx = warpSize * i + laneID ;
78+ int32_t idx = warpSize * i + laneID ;
8779 float newIn =
8880 idx >= reductionSize ? - FLT_MAX : inputBuffer [input_offset + idx ];
8981 if (newIn == laneMax )
9082 continue ;
91- laneMax = __ocml_fmax_f32 (newIn , laneMax );
83+ laneMax = __builtin_fmaxf (newIn , laneMax );
9284 laneResult = newIn == laneMax ? idx : laneResult ;
9385 }
9486
@@ -97,57 +89,58 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F32I64(float *inputBuffer,
9789 // https://github.com/iree-org/iree/issues/16112.
9890 float wgMax = laneMax ;
9991 for (int i = 1 ; i < warpSize ; i *= 2 ) {
100- wgMax = __ocml_fmax_f32 ( __shfl_xor (wgMax , i ), wgMax );
92+ wgMax = __builtin_fmaxf ( __shfl_xor_f (wgMax , i ), wgMax );
10193 }
10294 // Check if there are multiple max value holders.
10395 uint64_t laneHasMaxValmask = __ballot (wgMax == laneMax );
10496 // if there is only one max value holder, write and exit.
105- if (__popcll (laneHasMaxValmask ) == 1 ) {
97+ if (__builtin_popcountll (laneHasMaxValmask ) == 1 ) {
10698 if (wgMax == laneMax )
10799 outputBuffer [output_offset ] = laneResult ;
108100 return ;
109101 }
110102 // if there are multiple max value holder, find smallest index (argmax
111103 // semantics).
112- int64_t indexVal = wgMax == laneMax ? laneResult : __INT64_MAX__ ;
104+ int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX ;
113105 laneResult = __ockl_wfred_min_i64 (indexVal );
114106 if (laneID == 0 )
115107 outputBuffer [output_offset ] = laneResult ;
116108}
117109
118- extern "C" __device__ void __iree_uk_rocm_argmax_F16I32 (half * inputBuffer ,
119- size_t input_offset ,
120- int32_t * outputBuffer ,
121- size_t output_offset ,
122- size_t reductionSize ) {
123- half NEG_F16_MAX = __float2half (-65504.0f );
124- uint laneID = __builtin_amdgcn_workitem_id_x ();
110+ void __iree_uk_rocm_argmax_F16I32 (const _Float16 * inputBuffer ,
111+ int64_t input_offset , int32_t * outputBuffer ,
112+ int64_t output_offset ,
113+ int64_t reductionSize ) {
114+ const int warpSize = __builtin_amdgcn_wavefrontsize ();
115+ _Float16 NEG_F16_MAX = ( _Float16 ) (-65504.0f );
116+ int32_t laneID = __builtin_amdgcn_workitem_id_x ();
125117 // Set identity value to handle problem non divisible by subgroupSize.
126- half laneMax = laneID >= reductionSize ? NEG_F16_MAX
127- : inputBuffer [input_offset + laneID ];
118+ _Float16 laneMax = laneID >= reductionSize
119+ ? NEG_F16_MAX
120+ : inputBuffer [input_offset + laneID ];
128121 int32_t laneResult = laneID ;
129122
130- uint numBatches = (reductionSize + warpSize - 1 ) / warpSize ;
123+ int32_t numBatches = (reductionSize + warpSize - 1 ) / warpSize ;
131124 for (int i = 1 ; i < numBatches ; ++ i ) {
132- uint idx = warpSize * i + laneID ;
133- half newIn =
125+ int32_t idx = warpSize * i + laneID ;
126+ _Float16 newIn =
134127 idx >= reductionSize ? NEG_F16_MAX : inputBuffer [input_offset + idx ];
135128 if (newIn == laneMax )
136129 continue ;
137- laneMax = __ocml_fmax_f16 (newIn , laneMax );
130+ laneMax = __builtin_fmaxf16 (newIn , laneMax );
138131 laneResult = newIn == laneMax ? idx : laneResult ;
139132 }
140-
141133 // Final reduction with one subgroup
142- half wgMax = __ockl_wfred_max_f16 (laneMax );
134+ _Float16 wgMax = __ockl_wfred_max_f16 (laneMax );
143135 // Check if there are multiple max value holders.
144136 uint64_t laneHasMaxValmask = __ballot (wgMax == laneMax );
145137 // if there is only one max value holder, write and exit.
146- if (__popcll (laneHasMaxValmask ) == 1 ) {
138+ if (__builtin_popcountll (laneHasMaxValmask ) == 1 ) {
147139 if (wgMax == laneMax )
148140 outputBuffer [output_offset ] = laneResult ;
149141 return ;
150142 }
143+
151144 // if there are multiple max value holder, find smallest index (argmax
152145 // semantics).
153146 int32_t indexVal = wgMax == laneMax ? laneResult : __INT32_MAX__ ;
@@ -156,42 +149,43 @@ extern "C" __device__ void __iree_uk_rocm_argmax_F16I32(half *inputBuffer,
156149 outputBuffer [output_offset ] = laneResult ;
157150}
158151
159- extern "C" __device__ void __iree_uk_rocm_argmax_F16I64 (half * inputBuffer ,
160- size_t input_offset ,
161- int64_t * outputBuffer ,
162- size_t output_offset ,
163- size_t reductionSize ) {
164- half NEG_F16_MAX = __float2half (-65504.0f );
165- uint laneID = __builtin_amdgcn_workitem_id_x ();
152+ void __iree_uk_rocm_argmax_F16I64 (const _Float16 * inputBuffer ,
153+ int64_t input_offset , int64_t * outputBuffer ,
154+ int64_t output_offset ,
155+ int64_t reductionSize ) {
156+ const int warpSize = __builtin_amdgcn_wavefrontsize ();
157+ _Float16 NEG_F16_MAX = ( _Float16 ) (-65504.0f );
158+ int32_t laneID = __builtin_amdgcn_workitem_id_x ();
166159 // Set identity value to handle problem non divisible by subgroupSize.
167- half laneMax = laneID >= reductionSize ? NEG_F16_MAX
168- : inputBuffer [input_offset + laneID ];
160+ _Float16 laneMax = laneID >= reductionSize
161+ ? NEG_F16_MAX
162+ : inputBuffer [input_offset + laneID ];
169163 int64_t laneResult = laneID ;
170164
171- uint numBatches = (reductionSize + warpSize - 1 ) / warpSize ;
165+ int32_t numBatches = (reductionSize + warpSize - 1 ) / warpSize ;
172166 for (int i = 1 ; i < numBatches ; ++ i ) {
173- uint idx = warpSize * i + laneID ;
174- half newIn =
167+ int32_t idx = warpSize * i + laneID ;
168+ _Float16 newIn =
175169 idx >= reductionSize ? NEG_F16_MAX : inputBuffer [input_offset + idx ];
176170 if (newIn == laneMax )
177171 continue ;
178- laneMax = __ocml_fmax_f16 (newIn , laneMax );
172+ laneMax = __builtin_fmaxf16 (newIn , laneMax );
179173 laneResult = newIn == laneMax ? idx : laneResult ;
180174 }
181175
182176 // Final reduction with one subgroup
183- half wgMax = __ockl_wfred_max_f16 (laneMax );
177+ _Float16 wgMax = __ockl_wfred_max_f16 (laneMax );
184178 // Check if there are multiple max value holders.
185179 uint64_t laneHasMaxValmask = __ballot (wgMax == laneMax );
186180 // if there is only one max value holder, write and exit.
187- if (__popcll (laneHasMaxValmask ) == 1 ) {
181+ if (__builtin_popcountll (laneHasMaxValmask ) == 1 ) {
188182 if (wgMax == laneMax )
189183 outputBuffer [output_offset ] = laneResult ;
190184 return ;
191185 }
192186 // if there are multiple max value holder, find smallest index (argmax
193187 // semantics).
194- int64_t indexVal = wgMax == laneMax ? laneResult : __INT64_MAX__ ;
188+ int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX ;
195189 laneResult = __ockl_wfred_min_i64 (indexVal );
196190 if (laneID == 0 )
197191 outputBuffer [output_offset ] = laneResult ;
0 commit comments