
Commit cda8f3f

feat: IdType indices in sampling kernels (flashinfer-ai#2281)
## 📌 Description

Based on this [comment](flashinfer-ai#2127 (review)) in flashinfer-ai#2127, we can add support for Int64 indices as well. I decided to do this using `IdType`, as is done in other files.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Test results:

```
(flashinfer) raayan@uril-1:~/projects/flashinfer$ pytest tests/utils/test_sampling.py
============================================================= test session starts =============================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /home/raayan/projects/flashinfer
configfile: pytest.ini
collected 1884 items

tests/utils/test_sampling.py ...........................  [per-test progress output trimmed]  ........... [100%]

================================================ 1764 passed, 120 skipped in 546.33s (0:09:06) ================================================
(flashinfer) raayan@uril-1:~/projects/flashinfer$
```

## Reviewer Notes

---------

Signed-off-by: raayandhar <[email protected]>
1 parent b09fbcb commit cda8f3f
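At a glance, the change templates the sampling entry points on an index type `IdType` so that the output tensor (and the optional `indices` tensor) can be either int32 or int64. The snippet below is a minimal host-side sketch of that idea, not FlashInfer code: the name `SamplingFromProbHost`, its greedy argmax body, and its signature are illustrative stand-ins for the real stochastic CUDA kernels. The only point it demonstrates is that one templated routine covers both index widths.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative host-side stand-in for a sampling routine templated on the
// index type. The real FlashInfer kernels are stochastic CUDA kernels taking
// device pointers and a stream; the only point here is the IdType parameter:
// one templated routine serves both int32 and int64 outputs/indices.
template <typename T, typename IdType>
void SamplingFromProbHost(const T* probs, IdType* output, const IdType* indices,
                          unsigned int batch_size, unsigned int vocab_size) {
  for (unsigned int i = 0; i < batch_size; ++i) {
    // When provided, `indices` remaps request i to a row of `probs`.
    unsigned int row = indices ? static_cast<unsigned int>(indices[i]) : i;
    // Greedy argmax placeholder instead of real (philox-seeded) sampling.
    unsigned int best = 0;
    for (unsigned int v = 1; v < vocab_size; ++v) {
      if (probs[row * vocab_size + v] > probs[row * vocab_size + best]) best = v;
    }
    output[i] = static_cast<IdType>(best);
  }
}

int main() {
  std::vector<float> probs = {0.1f, 0.7f, 0.2f,   // request 0
                              0.5f, 0.1f, 0.4f};  // request 1
  std::vector<int32_t> out32(2);
  std::vector<int64_t> out64(2);
  SamplingFromProbHost<float, int32_t>(probs.data(), out32.data(), nullptr, 2, 3);
  SamplingFromProbHost<float, int64_t>(probs.data(), out64.data(), nullptr, 2, 3);
  std::printf("int32 ids: %d %d | int64 ids: %lld %lld\n",
              static_cast<int>(out32[0]), static_cast<int>(out32[1]),
              static_cast<long long>(out64[0]), static_cast<long long>(out64[1]));
  return 0;
}
```

The commit's actual wrappers do the same thing one level up: they inspect the runtime dtype of the output tensor and pick the `IdType` instantiation accordingly, as shown in the diffs below.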

File tree

4 files changed: +196 -88 lines changed


csrc/sampling.cu

Lines changed: 86 additions & 50 deletions
```diff
@@ -49,36 +49,48 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorV
                           bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
   CHECK_INPUT(logits);
   CHECK_DIM(2, logits);  // logits: (batch_size, vocab_size)
-  CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32);
+  CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
+  CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
   unsigned int batch_size = output.size(0);
   unsigned int vocab_size = logits.size(1);
 
   ffi::CUDADeviceGuard device_guard(logits.device().device_id);
   auto stream = get_stream(logits.device());
-  cudaError_t status = sampling::SamplingFromLogits(
-      static_cast<float*>(logits.data_ptr()), static_cast<int*>(output.data_ptr()),
-      maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value().data_ptr()) : nullptr,
-      batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
-  TVM_FFI_ICHECK(status == cudaSuccess)
-      << "SamplingFromLogits failed with error code " << cudaGetErrorString(status);
+
+  DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] {
+    cudaError_t status = sampling::SamplingFromLogits<float, IdType>(
+        static_cast<float*>(logits.data_ptr()), static_cast<IdType*>(output.data_ptr()),
+        maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
+                                  : nullptr,
+        batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
+    TVM_FFI_ICHECK(status == cudaSuccess)
+        << "SamplingFromLogits failed with error code " << cudaGetErrorString(status);
+    return true;
+  });
 }
 
 void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
                          bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
   CHECK_INPUT(probs);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
-  CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32);
+  CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
+  CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
   unsigned int batch_size = output.size(0);
   unsigned int vocab_size = probs.size(1);
 
   ffi::CUDADeviceGuard device_guard(probs.device().device_id);
   auto stream = get_stream(probs.device());
-  cudaError_t status = sampling::SamplingFromProb(
-      static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
-      maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value().data_ptr()) : nullptr,
-      batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
-  TVM_FFI_ICHECK(status == cudaSuccess)
-      << "SamplingFromProbs failed with error code " << cudaGetErrorString(status);
+
+  DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] {
+    cudaError_t status = sampling::SamplingFromProb<float, IdType>(
+        static_cast<float*>(probs.data_ptr()), static_cast<IdType*>(output.data_ptr()),
+        maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
+                                  : nullptr,
+        batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
+    TVM_FFI_ICHECK(status == cudaSuccess)
+        << "SamplingFromProbs failed with error code " << cudaGetErrorString(status);
+    return true;
+  });
 }
 
 void top_p_sampling_from_probs(TensorView probs, TensorView output,
@@ -87,21 +99,27 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
                                bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
   CHECK_INPUT(probs);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
-  CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32);
+  CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
+  CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
   unsigned int batch_size = output.size(0);
   unsigned int vocab_size = probs.size(1);
   check_tensor_param(maybe_top_p_arr, probs);
   bool has_top_p_arr = maybe_top_p_arr.has_value();
 
   ffi::CUDADeviceGuard device_guard(probs.device().device_id);
   auto stream = get_stream(probs.device());
-  cudaError_t status = sampling::TopPSamplingFromProb<float, int>(
-      static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
-      maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value().data_ptr()) : nullptr,
-      has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value().data_ptr()) : nullptr, batch_size,
-      top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
-  TVM_FFI_ICHECK(status == cudaSuccess)
-      << "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
+
+  DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] {
+    cudaError_t status = sampling::TopPSamplingFromProb<float, IdType>(
+        static_cast<float*>(probs.data_ptr()), static_cast<IdType*>(output.data_ptr()),
+        maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
+                                  : nullptr,
+        has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value().data_ptr()) : nullptr,
+        batch_size, top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
+    TVM_FFI_ICHECK(status == cudaSuccess)
+        << "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
+    return true;
+  });
 }
 
 void top_k_sampling_from_probs(TensorView probs, TensorView output,
@@ -113,21 +131,27 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
   CHECK_DEVICE(output, probs);
   CHECK_DIM(2, probs);   // probs: (batch_size, vocab_size)
   CHECK_DIM(1, output);  // output: (batch_size)
-  CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32);
+  CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
+  CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
   unsigned int batch_size = output.size(0);
   unsigned int vocab_size = probs.size(1);
   check_tensor_param(maybe_top_k_arr, probs);
   bool has_top_k_arr = maybe_top_k_arr.has_value();
 
   ffi::CUDADeviceGuard device_guard(probs.device().device_id);
   auto stream = get_stream(probs.device());
-  cudaError_t status = sampling::TopKSamplingFromProb<float, int>(
-      static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
-      maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value().data_ptr()) : nullptr,
-      has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
-      top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
-  TVM_FFI_ICHECK(status == cudaSuccess)
-      << "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status);
+
+  DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] {
+    cudaError_t status = sampling::TopKSamplingFromProb<float, IdType>(
+        static_cast<float*>(probs.data_ptr()), static_cast<IdType*>(output.data_ptr()),
+        maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
+                                  : nullptr,
+        has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value().data_ptr()) : nullptr,
+        batch_size, top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
+    TVM_FFI_ICHECK(status == cudaSuccess)
+        << "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status);
+    return true;
+  });
 }
 
 void min_p_sampling_from_probs(TensorView probs, TensorView output,
@@ -139,22 +163,28 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
   CHECK_DEVICE(output, probs);
   CHECK_DIM(2, probs);   // probs: (batch_size, vocab_size)
   CHECK_DIM(1, output);  // output: (batch_size)
-  CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32);
+  CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
+  CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
   unsigned int batch_size = output.size(0);
   unsigned int vocab_size = probs.size(1);
   check_tensor_param(maybe_min_p_arr, probs);
   bool has_min_p_arr = maybe_min_p_arr.has_value();
 
   ffi::CUDADeviceGuard device_guard(probs.device().device_id);
   auto stream = get_stream(probs.device());
-  cudaError_t status = sampling::MinPSamplingFromProb<float, int>(
-      static_cast<float*>(probs.data_ptr()),
-      has_min_p_arr ? static_cast<float*>(maybe_min_p_arr.value().data_ptr()) : nullptr,
-      static_cast<int*>(output.data_ptr()),
-      maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value().data_ptr()) : nullptr,
-      batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
-  TVM_FFI_ICHECK(status == cudaSuccess)
-      << "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status);
+
+  DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] {
+    cudaError_t status = sampling::MinPSamplingFromProb<float, IdType>(
+        static_cast<float*>(probs.data_ptr()),
+        has_min_p_arr ? static_cast<float*>(maybe_min_p_arr.value().data_ptr()) : nullptr,
+        static_cast<IdType*>(output.data_ptr()),
+        maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
+                                  : nullptr,
+        batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
+    TVM_FFI_ICHECK(status == cudaSuccess)
+        << "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status);
+    return true;
+  });
 }
 
 void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
@@ -168,7 +198,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
   CHECK_DEVICE(output, probs);
   CHECK_DIM(2, probs);   // probs: (batch_size, vocab_size)
   CHECK_DIM(1, output);  // output: (batch_size)
-  CHECK_MAYBE_INPUT_TYPE(maybe_indices, dl_int32);
+  CHECK_MAYBE_INPUT_TYPES(maybe_indices, dl_int32, dl_int64);
+  CHECK_MAYBE_SAME_DTYPE(maybe_indices, output);
   unsigned int batch_size = output.size(0);
   unsigned int vocab_size = probs.size(1);
   check_tensor_param(maybe_top_k_arr, probs);
@@ -178,16 +209,21 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
 
   ffi::CUDADeviceGuard device_guard(probs.device().device_id);
   auto stream = get_stream(probs.device());
-  cudaError_t status = sampling::TopKTopPSamplingFromProb<float, int>(
-      static_cast<float*>(probs.data_ptr()),
-      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr,
-      has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value().data_ptr()) : nullptr,
-      static_cast<int*>(output.data_ptr()),
-      maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value().data_ptr()) : nullptr,
-      batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
-      stream);
-  TVM_FFI_ICHECK(status == cudaSuccess)
-      << "TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
+
+  DISPATCH_DLPACK_IDTYPE_TO_CTYPE(output.dtype(), IdType, [&] {
+    cudaError_t status = sampling::TopKTopPSamplingFromProb<float, IdType>(
+        static_cast<float*>(probs.data_ptr()),
+        has_top_k_arr ? static_cast<IdType*>(maybe_top_k_arr.value().data_ptr()) : nullptr,
+        has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value().data_ptr()) : nullptr,
+        static_cast<IdType*>(output.data_ptr()),
+        maybe_indices.has_value() ? static_cast<IdType*>(maybe_indices.value().data_ptr())
+                                  : nullptr,
+        batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
+        stream);
+    TVM_FFI_ICHECK(status == cudaSuccess)
+        << "TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
+    return true;
+  });
 }
 
 void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_ids,
```
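All six wrappers above follow the same pattern: read the runtime dtype of `output`, let `DISPATCH_DLPACK_IDTYPE_TO_CTYPE` bind it to a compile-time `IdType`, and run a lambda that launches the templated kernel and returns `true` on success. The macro itself is not part of this diff, so the following is only a guessed, self-contained analog of such a dispatcher (the names `DISPATCH_IDTYPE_TO_CTYPE_SKETCH` and `DtypeLike` are made up); it shows why the lambda body can instantiate a `Kernel<float, IdType>` even though the dtype is only known at runtime.

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>

// Hypothetical stand-in for a DLPack dtype descriptor; in the real wrappers
// the dtype comes from output.dtype(). Only the integer bit width matters here.
struct DtypeLike {
  int bits;  // 32 or 64
};

// Guessed shape of a dtype-dispatch macro: map the runtime dtype to a
// compile-time C++ type, bind it to the name passed as `ctype`, and invoke the
// functor, which signals success by returning true (hence the `return true;`
// at the end of each wrapper lambda above).
#define DISPATCH_IDTYPE_TO_CTYPE_SKETCH(dtype, ctype, ...)    \
  [&]() -> bool {                                             \
    switch ((dtype).bits) {                                   \
      case 32: {                                              \
        using ctype = int32_t;                                \
        return (__VA_ARGS__)();                               \
      }                                                       \
      case 64: {                                              \
        using ctype = int64_t;                                \
        return (__VA_ARGS__)();                               \
      }                                                       \
      default:                                                \
        throw std::runtime_error("unsupported index dtype");  \
    }                                                         \
  }()

int main() {
  DtypeLike dt{64};  // pretend output.dtype() reported int64
  bool ok = DISPATCH_IDTYPE_TO_CTYPE_SKETCH(dt, IdType, [&] {
    // Inside the functor IdType is a concrete type (int64_t when dt.bits is 64),
    // so a templated kernel could be launched here as Kernel<float, IdType>(...).
    std::printf("dispatched to an index type of %zu bytes\n", sizeof(IdType));
    return true;
  });
  return ok ? 0 : 1;
}
```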

csrc/tvm_ffi_utils.h

Lines changed: 12 additions & 0 deletions
```diff
@@ -276,6 +276,18 @@ inline void check_shape(const tvm::ffi::TensorView& a, const tvm::ffi::TensorVie
   if (maybe_x.has_value()) {               \
     CHECK_INPUT_TYPE(maybe_x.value(), st); \
   }
+#define CHECK_MAYBE_INPUT_TYPES(maybe_x, st1, st2)                                    \
+  if (maybe_x.has_value()) {                                                          \
+    TVM_FFI_ICHECK(maybe_x.value().dtype() == st1 || maybe_x.value().dtype() == st2)  \
+        << "Inconsistency of Tensor type: " #maybe_x " must be " #st1 " or " #st2;    \
+  }
+#define CHECK_SAME_DTYPE(x, y)           \
+  TVM_FFI_ICHECK(x.dtype() == y.dtype()) \
+      << "Inconsistency of Tensor type: " #x " dtype must match " #y " dtype";
+#define CHECK_MAYBE_SAME_DTYPE(maybe_x, y) \
+  if (maybe_x.has_value()) {               \
+    CHECK_SAME_DTYPE(maybe_x.value(), y);  \
+  }
 #define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
   CHECK_CUDA(x);                           \
   CHECK_LAST_DIM_CONTIGUOUS(x)
```
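Together, the two new macros guarantee that when `maybe_indices` is supplied it is an int32 or int64 tensor, and that it uses the same dtype as `output`, so a single `IdType` instantiation can serve both pointers inside the dispatch lambda. The standalone sketch below, built around a made-up `FakeTensor` that carries only a dtype bit width, mirrors those two invariants as plain functions rather than TVM FFI macros.

```cpp
#include <cassert>
#include <optional>

// Made-up miniature of a tensor: the real macros compare DLPack dtypes
// (dl_int32 / dl_int64); here a plain bit width plays that role.
struct FakeTensor {
  int dtype_bits;  // 32 or 64
};

// Analog of CHECK_MAYBE_INPUT_TYPES(maybe_x, dl_int32, dl_int64): an absent
// indices tensor is fine; a present one must use one of the two index dtypes.
bool maybe_input_types_ok(const std::optional<FakeTensor>& maybe_x) {
  return !maybe_x.has_value() || maybe_x->dtype_bits == 32 || maybe_x->dtype_bits == 64;
}

// Analog of CHECK_MAYBE_SAME_DTYPE(maybe_x, y): a present indices tensor must
// match the dtype of the reference tensor (the sampling output).
bool maybe_same_dtype_ok(const std::optional<FakeTensor>& maybe_x, const FakeTensor& y) {
  return !maybe_x.has_value() || maybe_x->dtype_bits == y.dtype_bits;
}

int main() {
  FakeTensor output{64};                                  // int64 output token ids
  std::optional<FakeTensor> indices = FakeTensor{64};     // matching int64 indices
  assert(maybe_input_types_ok(indices));
  assert(maybe_same_dtype_ok(indices, output));

  std::optional<FakeTensor> mismatched = FakeTensor{32};  // int32 indices vs int64 output
  assert(maybe_input_types_ok(mismatched));
  assert(!maybe_same_dtype_ok(mismatched, output));       // rejected: dtypes must match
  return 0;
}
```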
