@@ -51,7 +51,7 @@ struct TopKRedType
5151 static __host__ __device__ inline TypeCmp makeCmpVal (T val, int32_t idx = 0 )
5252 {
5353 auto valueBits = cub::Traits<T>::TwiddleIn (reinterpret_cast <typename cub::Traits<T>::UnsignedBits&>(val));
54- TypeCmp compactTmp = reinterpret_cast <TypeCmp&>( valueBits) ;
54+ TypeCmp compactTmp = valueBits;
5555 compactTmp = (compactTmp << kMoveBits ) | (0xFFFF & (kMaxIdx - idx));
5656 // Use 65535 minus idx to give higher priority to elements with smaller indices.
5757 return compactTmp;
@@ -162,9 +162,28 @@ struct Sort<4, RedType>
162162 }
163163};
164164
// Warp-wide top-K selection for the case where every lane contributes exactly one candidate.
//
// Each round packs this lane's (value, idx) candidate into a TopKRedType, calls reduce() on it across
// `warp`, and unpacks the returned winner into out[kk] / outIdx[kk]. Before every round after the first,
// the lane whose packed candidate equals the previous winner swaps its candidate for (minValue, idx) so
// that the same element is not returned twice.
//
// Template parameters:
//   K        Capacity of out / outIdx; must satisfy 0 < K < kWARP_SIZE.
//   Type     Value type of the candidates.
// Parameters:
//   warp     Tile of kWARP_SIZE threads handed to the reduction.
//   out      Receives the selected values, one per round.
//   outIdx   Receives the indices paired with the selected values.
//   value    This lane's candidate value.
//   idx      This lane's candidate index.
//   minValue Value substituted for a candidate once it has been selected.
//   actualK  Number of rounds to run (defaults to K). Values larger than K are clamped to K so that
//            out / outIdx are never written out of bounds; entries at positions >= actualK are left
//            untouched.
//
// NOTE(review): actualK is presumably uniform across the tile, since every round performs a reduction
// over `warp` -- confirm against callers.
template <int K, typename Type>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K)
{
    static_assert(K > 0, " Top K must have K > 0");
    static_assert(K < kWARP_SIZE, " Top K must have K < kWARP_SIZE");
    using RedType = TopKRedType<Type>;

    RedType topK{value, idx};
    typename RedType::TypeCmp packedMax{};

    // Iterate over the compile-time bound K and leave early at actualK: this keeps every write to
    // out / outIdx inside the arrays even if actualK > K, and gives `#pragma unroll` a constant trip count.
#pragma unroll
    for (int kk = 0; kk < K; ++kk)
    {
        if (kk >= actualK)
        {
            break;
        }
        // The lane that supplied the previous winner retires its candidate.
        if (kk > 0 && packedMax == topK.compValIdx)
        {
            topK = RedType{minValue, idx};
        }
        // get the next largest value
        packedMax = topK.reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
}
183+
165184template <int K, typename Type, int N, bool IsSorted = false >
166- __device__ void reduceTopK (cg::thread_block_tile<kWARP_SIZE > const & warp, Type (&out)[K], int32_t (&outIdx)[K],
167- Type (&value)[N], int32_t (&idx)[N], Type minValue)
185+ __device__ void reduceTopKFunc (cg::thread_block_tile<kWARP_SIZE > const & warp, Type (&out)[K], int32_t (&outIdx)[K],
186+ Type (&value)[N], int32_t (&idx)[N], Type minValue, int actualK = K )
168187{
169188 static_assert (K > 0 , " Top K must have K > 0" );
170189 static_assert (K < kWARP_SIZE , " Top K must have K < kWARP_SIZE" );
@@ -184,7 +203,7 @@ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (
184203 }
185204 typename RedType::TypeCmp packedMax{};
186205#pragma unroll
187- for (int kk = 0 ; kk < K ; ++kk)
206+ for (int kk = 0 ; kk < actualK ; ++kk)
188207 {
189208 bool update = kk > 0 && packedMax == topK[0 ].compValIdx ;
190209#pragma unroll
@@ -198,6 +217,67 @@ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (
198217 }
199218};
200219
// Warp-wide top-K selection for the case where every lane contributes N candidates.
//
// For N <= 4 the call is forwarded unchanged to reduceTopKFunc. For larger N the candidates are processed
// in N / 4 chunks of four per lane:
//   1. reduceTopKFunc<K, Type, 4> computes the winners of each chunk.
//   2. Winner j of chunk `loop` is assigned the global slot (loop * K + j); slot s is stored by lane
//      (s % kWARP_SIZE) in entry (s / kWARP_SIZE) of a small per-lane buffer, so every winner is kept by
//      exactly one lane. Buffer entries that receive no winner keep minValue.
//   3. A final reduceTopKFunc over the per-lane buffers produces out / outIdx.
//
// Template parameters:
//   K        Capacity of out / outIdx; must satisfy 0 < K < kWARP_SIZE.
//   Type     Value type of the candidates.
//   N        Candidates per lane; 0 < N <= 16 and either N <= 4 or N is a multiple of 4.
// Parameters:
//   warp     Tile of kWARP_SIZE threads handed to the reductions.
//   out      Receives the selected values.
//   outIdx   Receives the indices paired with the selected values.
//   value    This lane's candidate values.
//   idx      This lane's candidate indices.
//   minValue Filler value for buffer entries that hold no candidate.
//   actualK  Number of results requested (defaults to K); values larger than K are clamped to K.
//
// NOTE(review): laneIdx is derived from threadIdx.x % kWARP_SIZE, which presumably matches the lane rank
// inside `warp` for the launch layouts used by callers -- confirm.
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
{
    static_assert(K > 0, " Top K must have K > 0");
    static_assert(K < kWARP_SIZE, " Top K must have K < kWARP_SIZE");
    static_assert(N > 0, " Top K must have N > 0");
    static_assert(N <= 16, " Only support candidates number less than or equal to 16*32=512");
    static_assert(
        N <= 4 || N % 4 == 0, " Only support candidates number is a multiple of 4*32=128 or less than or equal to 4");

    // Never ask for more results than the output arrays can hold.
    int const numSelect = actualK < K ? actualK : K;

    if constexpr (N <= 4)
    {
        reduceTopKFunc<K, Type, N>(warp, out, outIdx, value, idx, minValue, numSelect);
    }
    else
    {
        constexpr int numLoops = N / 4;
        // Buffer entries each lane needs so that all numLoops * K chunk winners fit across the warp.
        constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1;

        Type topKBufferValue[numResults];
        int32_t topKBufferIdx[numResults];
        int32_t const laneIdx = threadIdx.x % kWARP_SIZE;

#pragma unroll
        for (int ii = 0; ii < numResults; ++ii)
        {
            topKBufferValue[ii] = minValue;
            topKBufferIdx[ii] = ii * kWARP_SIZE - 1; // @todo: check if this is correct
        }

#pragma unroll
        for (int loop = 0; loop < numLoops; ++loop)
        {
            int const start = loop * 4;
            Type topKValue[K];
            int32_t topKIdx[K];
            Type inValue[4];
            int32_t inIdx[4];
#pragma unroll
            for (int i = 0; i < 4; ++i)
            {
                inValue[i] = value[start + i];
                inIdx[i] = idx[start + i];
            }
            reduceTopKFunc<K, Type, 4>(warp, topKValue, topKIdx, inValue, inIdx, minValue, numSelect);

            // Scatter this chunk's winners to their owning lane / buffer entry. Compared with hard-coded
            // writes to entries [0] and [1], this keeps every winner when numLoops * K exceeds kWARP_SIZE
            // and never names a buffer entry >= numResults. Only the first numSelect winners are copied:
            // reduceTopKFunc is asked for numSelect results, so later entries of topKValue / topKIdx are
            // not expected to hold valid data. The loops are unrolled with the intent that all array
            // indices below become compile-time constants.
#pragma unroll
            for (int j = 0; j < K; ++j)
            {
                int const slot = loop * K + j;
                if (j < numSelect && slot % kWARP_SIZE == laneIdx)
                {
                    topKBufferValue[slot / kWARP_SIZE] = topKValue[j];
                    topKBufferIdx[slot / kWARP_SIZE] = topKIdx[j];
                }
            }
        }

        reduceTopKFunc<K, Type, numResults>(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, numSelect);
    }
}
280+
201281#undef TOPK_SWAP
202282
203283} // namespace reduce_topk
0 commit comments