
Commit f174268

ClementLinCF authored, with co-authors Copilot, valarLip, and MHYang (AMD)
feat: Adaptive topk algorithm selection based on input characteristics (ROCm#1578)
* Add radix-base selection
* Remove explicit template
* Update the selected k condition
* Remove pos < k guard
* Code format
* Update csrc/include/rocm_ops.hpp
* Update csrc/kernels/topk_per_row_kernels.cu
* Update csrc/kernels/topk_plain_kernels.cu
* Update test_topk_plain.py
* Update TODO message
* Update csrc/kernels/topk_per_row_kernels.cu
* Update op_tests/test_topk_plain.py
* Format test_topk_plain.py with black
* Disable triton test for a reasonable execution time
* Add explicit template instantiation
* Fix explicit template instantiation
* Add explicit template instantiation
* Add bf16 support
* Fix linter
* Fix build errors
* Fix condition
* Fix build and test
* Update conditions

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>
Co-authored-by: MHYang <meng-hsuan.yang@amd.com>
1 parent 420f5da commit f174268
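
The title's "adaptive topk algorithm selection" refers to picking a selection algorithm from input characteristics such as k and row length. As a rough, hypothetical sketch of the idea only — the cutoff and the two CPU-side algorithms below stand in for the GPU kernel choice and are not taken from this diff:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Illustrative input-driven top-k selection: pick the algorithm from the
// ratio of k to the row length. Returns the k largest values, descending.
std::vector<float> topk_select(std::vector<float> v, std::size_t k)
{
    if(k * 32 < v.size())
    {
        // k much smaller than the row: nth_element-style selection does less work.
        std::nth_element(v.begin(), v.begin() + k, v.end(), std::greater<float>());
        std::sort(v.begin(), v.begin() + k, std::greater<float>());
    }
    else
    {
        // k close to the row length: a full partial sort is competitive.
        std::partial_sort(v.begin(), v.begin() + k, v.end(), std::greater<float>());
    }
    v.resize(k);
    return v;
}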

8 files changed (+1212, -409 lines)


aiter/jit/optCompilerConfig.json

Lines changed: 2 additions & 1 deletion
@@ -1077,7 +1077,8 @@
     "module_topk_plain": {
         "srcs": [
             "f'{AITER_CSRC_DIR}/pybind/topk_plain_pybind.cu'",
-            "f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'"
+            "f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'",
+            "f'{AITER_CSRC_DIR}/kernels/topk_per_row_kernels.cu'"
         ],
         "flags_extra_cc": [],
         "flags_extra_hip": [],

aiter/ops/topk_plain.py

Lines changed: 6 additions & 1 deletion
@@ -13,7 +13,12 @@
 def topk_plain(
     x: torch.Tensor,
     topk_ids: torch.Tensor,
+    topk_out: torch.Tensor,
     topk: int,
-    largest: bool,
+    largest: bool = True,
+    rowStarts: torch.Tensor = None,
+    rowEnds: torch.Tensor = None,
+    stride0: int = -1,
+    stride1: int = 1,
 ) -> None:
     pass

csrc/include/opus/opus.hpp

Lines changed: 1 addition & 1 deletion
@@ -907,7 +907,7 @@ template<> OPUS_D float min<float>(const float&a, const float&b) { return
 
 template<typename T> OPUS_D T med3(const T&a, const T&b, const T&c) { auto max_0 = max(a, b); auto min_0 = max(a, b); return max(max_0, max(min_0, c)); }
 template<> OPUS_D float med3<float>(const float&a, const float&b, const float&c) { return __builtin_amdgcn_fmed3f(a, b, c); }
-template<> OPUS_D __fp16 med3<__fp16>(const __fp16&a, const __fp16&b, const __fp16&c) { return __builtin_amdgcn_fmed3h(a, b, c); }
+template<> OPUS_D _Float16 med3<_Float16>(const _Float16&a, const _Float16&b, const _Float16&c) { return __builtin_amdgcn_fmed3h(a, b, c); }
 /////////////////////////////////////////////////////////////////////////////////////////////////////////
 // buffer load/store related
 OPUS_D constexpr auto buffer_default_config() {
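
For readers outside the OPUS codebase: the fmed3 builtins map to AMD's hardware median-of-three instructions, and the __fp16 to _Float16 change swaps in the half type that hipcc handles more uniformly (plausibly tied to the "Fix build errors" item in the commit message). Note also that the generic fallback above assigns max(a, b) to min_0, where min(a, b) looks intended; the float and _Float16 specializations bypass the fallback entirely. A plain scalar reference of what med3 computes:

#include <algorithm>

// Reference median-of-three: equivalent to the hardware fmed3 for
// finite, non-NaN inputs.
float med3_ref(float a, float b, float c)
{
    return std::max(std::min(a, b), std::min(std::max(a, b), c));
}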

csrc/include/rocm_ops.hpp

Lines changed: 12 additions & 7 deletions
@@ -1635,10 +1635,15 @@ namespace py = pybind11;
           py::arg("final_output"), \
           py::arg("final_lse") = std::nullopt);
 
-#define TOPK_PLAIN_PYBIND         \
-    m.def("topk_plain",           \
-          &topk_plain,            \
-          py::arg("values"),      \
-          py::arg("topk_ids"),    \
-          py::arg("topk"),        \
-          py::arg("largest"));
+#define TOPK_PLAIN_PYBIND                         \
+    m.def("topk_plain",                           \
+          &topk_plain,                            \
+          py::arg("values"),                      \
+          py::arg("topk_ids"),                    \
+          py::arg("topk_out"),                    \
+          py::arg("topk"),                        \
+          py::arg("largest") = true,              \
+          py::arg("rowStarts") = torch::Tensor(), \
+          py::arg("rowEnds") = torch::Tensor(),   \
+          py::arg("stride0") = -1,                \
+          py::arg("stride1") = 1);

csrc/include/topk_plain.h

Lines changed: 7 additions & 2 deletions
@@ -6,5 +6,10 @@
 
 void topk_plain(torch::Tensor& values,
                 torch::Tensor& topk_ids,
-                int topk_num,
-                bool largest);
+                torch::Tensor& topk_out,
+                int topk,
+                bool largest = true,
+                torch::Tensor rowStarts = torch::Tensor(),
+                torch::Tensor rowEnds = torch::Tensor(),
+                int64_t stride0 = -1,
+                int64_t stride1 = 1);
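
A minimal calling sketch against this updated signature (the tensor shapes and the int32 index dtype are assumptions, not stated in the diff):

#include <torch/torch.h>
#include "topk_plain.h"

// Sketch only: values is assumed [num_rows, row_len] on the GPU, with
// topk_ids/topk_out preallocated as [num_rows, topk].
void topk_plain_example()
{
    auto values   = torch::rand({4, 1024}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
    auto topk_ids = torch::empty({4, 8}, torch::dtype(torch::kInt32).device(torch::kCUDA));
    auto topk_out = torch::empty({4, 8}, values.options());

    // rowStarts/rowEnds stay empty (select over each full row); stride0 = -1
    // leaves the row stride to be derived from the input layout.
    topk_plain(values, topk_ids, topk_out, /*topk=*/8, /*largest=*/true);
}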

csrc/kernels/topk_per_row_kernels.cu

Lines changed: 72 additions & 12 deletions
@@ -420,7 +420,8 @@ __device__ void filter_and_histogram(T const* in_buf,
                                      IdxT* histogram,
                                      bool select_min,
                                      int pass,
-                                     bool early_stop)
+                                     bool early_stop,
+                                     IdxT k)
 {
     constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
     __shared__ IdxT histogram_smem[num_buckets];
@@ -893,9 +894,19 @@ __global__ void radix_kernel(T const* in,
                              int const pass)
 {
     const int64_t batch_id = blockIdx.y;
-    const IdxT row_len = phase == Phase::Prefill
-                             ? rowEnds[batch_id] - rowStarts[batch_id]
-                             : rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
+
+    IdxT row_len = len;
+    if(phase == Phase::Prefill)
+    {
+        if(rowStarts && rowEnds)
+        {
+            row_len = rowEnds[batch_id] - rowStarts[batch_id];
+        }
+    }
+    else
+    {
+        row_len = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
+    }
 
     auto counter = counters + batch_id;
     IdxT current_k;
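
Two things change in this hunk: prefill now falls back to row_len = len when rowStarts/rowEnds are null, which lets callers without per-row bounds (such as the extended topk_plain entry point) reuse this kernel, and the decode branch keeps its staircase length formula. The decode arithmetic is easiest to read with concrete numbers; a host-side rendering with illustrative values:

#include <cstdio>

// Decode phase: each of the next_n speculative rows sees one more element
// than the previous. With rowEnds[0] = 100 and next_n = 4, rows 0..3 get
// lengths 97, 98, 99, 100.
int main()
{
    const int next_n    = 4;
    const int rowEnds[] = {100};
    for(int batch_id = 0; batch_id < next_n; ++batch_id)
    {
        int row_len = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
        std::printf("batch %d: row_len = %d\n", batch_id, row_len);
    }
    return 0;
}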
@@ -965,7 +976,8 @@
                          histogram,
                          select_min,
                          pass,
-                         early_stop);
+                         early_stop,
+                         k);
     __threadfence();
 
     bool isLastBlock = false;
@@ -1187,7 +1199,8 @@ __device__ bool filter_and_histogram_for_one_block(T const* in_buf,
                                                    Counter<T, IdxT>* counter,
                                                    IdxT* histogram,
                                                    bool select_min,
-                                                   int pass)
+                                                   int pass,
+                                                   IdxT k)
 {
     constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
     for(int i = threadIdx.x; i < num_buckets * 2; i += blockDim.x)
@@ -1371,11 +1384,25 @@ __global__ void radix_topk_one_block_kernel(T const* in,
     __shared__ IdxT histogram[num_buckets * 2];
 
     const int64_t batch_id = blockIdx.x;
-    const IdxT rowStart = phase == Phase::Prefill ? rowStarts[batch_id] : 0;
-    const IdxT rowEnd = phase == Phase::Prefill
-                            ? rowEnds[batch_id]
-                            : rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
-    const IdxT row_len = rowEnd - rowStart;
+
+    IdxT rowStart = 0;
+    IdxT rowEnd = len;
+    if(phase == Phase::Prefill)
+    {
+        if(rowStarts && rowEnds)
+        {
+            rowStart = rowStarts[batch_id];
+            rowEnd = rowEnds[batch_id];
+        }
+    }
+    else
+    {
+        rowEnd = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
+        rowStart = 0;
+    }
+
+    const IdxT row_len = rowEnd - rowStart;
+
     if(threadIdx.x == 0)
     {
         counter.k = k;
@@ -1448,7 +1475,8 @@ __global__ void radix_topk_one_block_kernel(T const* in,
                                             &counter,
                                             histogram,
                                             select_min,
-                                            pass); //@TODO CHECK UPDATE CODE
+                                            pass,
+                                            k); //@TODO CHECK UPDATE CODE
     __syncthreads();
 
     scan<IdxT, BitsPerPass, BlockSize>(histogram + use_one_pass * num_buckets);
@@ -1811,6 +1839,35 @@ void standalone_stable_radix_11bits(void* buf,
     }
 }
 
+// Explicit template instantiation for standalone_stable_radix_11bits
+template void standalone_stable_radix_11bits<float, int, true, true>(void* buf,
+                                                                     size_t& buf_size,
+                                                                     float const* in,
+                                                                     int batch_size,
+                                                                     int64_t len,
+                                                                     int* rowStarts,
+                                                                     int* rowEnds,
+                                                                     int k,
+                                                                     float* out,
+                                                                     int* out_idx,
+                                                                     bool greater,
+                                                                     hipStream_t stream,
+                                                                     int next_n);
+
+template void standalone_stable_radix_11bits<float, int, false, true>(void* buf,
+                                                                      size_t& buf_size,
+                                                                      float const* in,
+                                                                      int batch_size,
+                                                                      int64_t len,
+                                                                      int* rowStarts,
+                                                                      int* rowEnds,
+                                                                      int k,
+                                                                      float* out,
+                                                                      int* out_idx,
+                                                                      bool greater,
+                                                                      hipStream_t stream,
+                                                                      int next_n);
+
 // AIR TopK end
 
 static inline __device__ uint32_t floatAsSortableUint(float x)
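
These instantiations exist because the template definitions live in this .cu file while callers sit in other translation units; without them the linker would have no concrete float/int symbols to resolve. The general pattern, with toy names for illustration:

// lib.cpp: the definition is visible only in this translation unit.
template <typename T>
T clamp_min(T v, T lo) { return v < lo ? lo : v; }

// Explicit instantiation forces the compiler to emit the float symbol here,
// so other translation units can link against a bare declaration.
template float clamp_min<float>(float v, float lo);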
@@ -2410,6 +2467,9 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
     return buf_size;
 }
 
+// Explicit template instantiation to ensure the symbol is available for linking
+template int64_t invokeComputeTopkLastDimWorkspaceSize<float>(int32_t numRows, int32_t stride0);
+
 void top_k_per_row_prefill(const torch::Tensor& logits,
                            const torch::Tensor& rowStarts,
                            const torch::Tensor& rowEnds,
