
Commit bbf4035

feat: integrate xqa attention backend (#1503)
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they're needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->

---------

Signed-off-by: Qidi Sang <[email protected]>
1 parent 185e048 commit bbf4035

20 files changed (+7225, -0 lines)

csrc/flashinfer_xqa_ops.cu

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pytorch_extension_utils.h"

void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize,
                 double qScale, at::Tensor output,
#if LOW_PREC_OUTPUT
                 at::Tensor rcpOutScale,
#endif
                 at::Tensor q, at::Tensor attentionSinks, at::Tensor pool,
                 at::Tensor kvCachePageList, int64_t maxSeqLen, at::Tensor seqLen,
                 int64_t batchSize, at::Tensor kvCacheScale,
#if SPEC_DEC
                 int64_t qSeqLen, at::Tensor qCuSeqLens, at::Tensor mask,
#endif
                 at::Tensor semaphores, at::Tensor scratch);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  // "XQA Wrapper"
  m.def("xqa_wrapper", xqa_wrapper);
}
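
For readers unfamiliar with the binding mechanism used here: `TORCH_LIBRARY_FRAGMENT` registers a C++ function as a custom operator in a PyTorch extension's namespace, so the op becomes callable through the dispatcher (e.g., from Python via `torch.ops`) once the extension is loaded. Below is a minimal, self-contained sketch of the same registration pattern using a hypothetical toy op; the namespace `demo_ops` and the function `scale_tensor` are illustrative only and not part of this commit.

```cpp
// Minimal sketch of the TORCH_LIBRARY_FRAGMENT pattern used above.
// The namespace "demo_ops" and the op "scale_tensor" are hypothetical.
#include <ATen/ATen.h>
#include <torch/library.h>

// A toy op: multiply every element of a tensor by a scalar.
at::Tensor scale_tensor(at::Tensor input, double scale) {
  return input.mul(scale);
}

// Register the op so it is callable as demo_ops::scale_tensor.
// Passing the C++ function lets PyTorch infer the schema automatically,
// mirroring the m.def("xqa_wrapper", xqa_wrapper) call in this file.
TORCH_LIBRARY_FRAGMENT(demo_ops, m) {
  m.def("scale_tensor", scale_tensor);
}
```

Once such an extension is built and loaded, the toy op would be reachable from Python as `torch.ops.demo_ops.scale_tensor(x, 2.0)` (namespace name assumed). In this PR the registration uses `TORCH_EXTENSION_NAME`, a macro that PyTorch's extension build substitutes with the extension's actual name.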
