Skip to content

Commit dd20f55

Browse files
authored
bugfix: fix shared memory alignment conflict in sampling.cuh (#1402)
<!-- .github/pull_request_template.md --> ## πŸ“Œ Description Fix compilation issue on gb200 due to shared mem alignment conflict. ## πŸ” Related Issues #1400 ## πŸš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### βœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## πŸ§ͺ Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 7e3eb8d commit dd20f55

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

β€Žinclude/flashinfer/sampling.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,8 +739,8 @@ __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType*
739739
const uint32_t row_idx = indices == nullptr ? bx : indices[bx];
740740
using SharedMem = typename BlockReduce<DataAndIndex<DType, IdType>, BLOCK_THREADS,
741741
REDUCE_ALGORITHM>::TempStorage;
742-
extern __shared__ __align__(alignof(SharedMem)) uint8_t smem_sampling[];
743-
auto& temp_storage = reinterpret_cast<SharedMem&>(smem_sampling);
742+
extern __shared__ __align__(alignof(SharedMem)) uint8_t smem_sampling_logit[];
743+
auto& temp_storage = reinterpret_cast<SharedMem&>(smem_sampling_logit);
744744

745745
vec_t<DType, VEC_SIZE> logits_vec;
746746
DataAndIndex<DType, IdType> max_data = {-cuda::std::numeric_limits<DType>::infinity(), 0};

0 commit comments

Comments
Β (0)