Skip to content

Commit a533db2

Browse files
committed
bugfix: add workspace for flashinfer.
1 parent f745add commit a533db2

File tree

8 files changed

+122
-5
lines changed

8 files changed

+122
-5
lines changed

xllm/core/common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_library(
1515
rate_limiter.h
1616
types.h
1717
device_monitor.h
18+
flashinfer_workspace.h
1819
SRCS
1920
etcd_client.cpp
2021
global_flags.cpp
@@ -23,6 +24,7 @@ cc_library(
2324
options.cpp
2425
rate_limiter.cpp
2526
device_monitor.cpp
27+
flashinfer_workspace.cpp
2628
DEPS
2729
util
2830
absl::random_random
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "flashinfer_workspace.h"
17+
18+
#include "global_flags.h"
19+
20+
namespace xllm {
21+
22+
// Allocates the flashinfer workspace buffers on `device` (plus a pinned host
// mirror), each FLAGS_workspace_buffer_size bytes of uint8 storage.
//
// Only the first call takes effect. Every LLMWorkerImpl constructor calls
// this on the process-wide singleton; without a guard, a second worker would
// re-allocate the buffers (possibly on a different device) while attention
// layers already hold handles to the old storage.
void FlashinferWorkspace::initialize(const torch::Device& device) {
  if (float_workspace_buffer_.defined()) {
    // Already initialized — keep the existing buffers and device binding.
    return;
  }

  // Device scratch used to store intermediate attention results in the
  // split-k algorithm (see FLAGS_workspace_buffer_size description).
  float_workspace_buffer_ =
      torch::empty({FLAGS_workspace_buffer_size},
                   torch::dtype(torch::kUInt8).device(device));
  // Device scratch for flashinfer's integer workspace.
  int_workspace_buffer_ =
      torch::empty({FLAGS_workspace_buffer_size},
                   torch::dtype(torch::kUInt8).device(device));
  // Page-locked (pinned) host mirror of the int workspace, enabling fast
  // host-to-device transfers of flashinfer's planning data.
  page_locked_int_workspace_buffer_ = torch::empty(
      {FLAGS_workspace_buffer_size},
      torch::dtype(torch::kUInt8).device(torch::kCPU).pinned_memory(true));
}
33+
34+
// Returns a shallow handle to the device float workspace buffer; the
// returned tensor shares storage with the singleton's buffer.
torch::Tensor FlashinferWorkspace::get_float_workspace_buffer() {
  return float_workspace_buffer_;
}
37+
38+
// Returns a shallow handle to the device int workspace buffer; the returned
// tensor shares storage with the singleton's buffer.
torch::Tensor FlashinferWorkspace::get_int_workspace_buffer() {
  return int_workspace_buffer_;
}
41+
42+
// Returns a shallow handle to the pinned-host mirror of the int workspace
// buffer; the returned tensor shares storage with the singleton's buffer.
torch::Tensor FlashinferWorkspace::get_page_locked_int_workspace_buffer() {
  return page_locked_int_workspace_buffer_;
}
45+
46+
} // namespace xllm
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#include <torch/torch.h>
19+
20+
#include <cstdint>
21+
22+
#include "macros.h"
23+
24+
namespace xllm {
25+
26+
class FlashinferWorkspace {
27+
public:
28+
static FlashinferWorkspace& get_instance() {
29+
static FlashinferWorkspace instance;
30+
return instance;
31+
};
32+
33+
void initialize(const torch::Device& device);
34+
35+
torch::Tensor get_float_workspace_buffer();
36+
torch::Tensor get_int_workspace_buffer();
37+
torch::Tensor get_page_locked_int_workspace_buffer();
38+
39+
private:
40+
FlashinferWorkspace() = default;
41+
~FlashinferWorkspace() = default;
42+
DISALLOW_COPY_AND_ASSIGN(FlashinferWorkspace);
43+
44+
torch::Tensor float_workspace_buffer_;
45+
torch::Tensor int_workspace_buffer_;
46+
torch::Tensor page_locked_int_workspace_buffer_;
47+
};
48+
49+
} // namespace xllm

xllm/core/common/global_flags.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ DEFINE_string(store_metadata_connstring,
343343
"",
344344
"The address of the kv cache store metadata service.");
345345

346-
// --- for computation communication parallel ---
346+
// --- computation communication parallel config ---
347347

348348
DEFINE_bool(
349349
enable_multi_stream_parallel,
@@ -355,7 +355,7 @@ DEFINE_int32(default_micro_batch_num,
355355
2,
356356
"Default use two micro batches for multi-stream parallel.");
357357

358-
// --- for dit ---
358+
// --- dit config ---
359359
DEFINE_int32(max_requests_per_batch, 1, "Max number of request per batch.");
360360

361361
// --- continuous kv cache config ---
@@ -378,3 +378,9 @@ DEFINE_int64(cache_size_per_token,
378378
DEFINE_int64(buffer_size_per_seq,
379379
0,
380380
"Buffer size per sequence in bytes, default 0.");
381+
382+
// --- flashinfer config ---
383+
DEFINE_int32(workspace_buffer_size,
384+
512 * 1024 * 1024,
385+
"The user reserved workspace buffer used to store intermediate "
386+
"attention results in split-k algorithm.");

xllm/core/common/global_flags.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,6 @@ DECLARE_int32(max_global_ttft_ms);
189189

190190
DECLARE_int32(max_global_tpot_ms);
191191

192-
// dit
193192
DECLARE_int32(max_requests_per_batch);
194193

195194
DECLARE_bool(enable_continuous_kvcache);
@@ -199,3 +198,5 @@ DECLARE_int64(granularity_size);
199198
DECLARE_int64(cache_size_per_token);
200199

201200
DECLARE_int64(buffer_size_per_seq);
201+
202+
DECLARE_int32(workspace_buffer_size);

xllm/core/kernels/ops_api.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ void batch_decode(AttentionParams& params) {
140140
params.k_cache,
141141
params.output,
142142
params.block_table,
143-
params.seq_lens,
143+
params.kv_seq_lens,
144144
params.v_cache,
145145
params.output_lse,
146146
params.q_quant_scale,

xllm/core/layers/mlu/attention.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "attention.h"
1717

18+
#include "common/flashinfer_workspace.h"
1819
#include "kernels/ops_api.h"
1920

2021
DECLARE_bool(enable_chunked_prefill);
@@ -99,6 +100,14 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
99100
attention_params.window_size_left = sliding_window_;
100101
attention_params.scale = scale_;
101102
attention_params.compute_dtype = attn_metadata.compute_dtype;
103+
// for flashinfer
104+
attention_params.float_workspace_buffer =
105+
FlashinferWorkspace::get_instance().get_float_workspace_buffer();
106+
attention_params.int_workspace_buffer =
107+
FlashinferWorkspace::get_instance().get_int_workspace_buffer();
108+
attention_params.page_locked_int_workspace_buffer =
109+
FlashinferWorkspace::get_instance()
110+
.get_page_locked_int_workspace_buffer();
102111
attention_params.kv_cu_seq_lens = attn_metadata.kv_cu_seq_lens;
103112
attention_params.q_cu_seq_lens = attn_metadata.q_cu_seq_lens;
104113

xllm/core/runtime/llm_worker_impl.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include <utility>
2727

2828
#include "common/device_monitor.h"
29+
#include "common/flashinfer_workspace.h"
2930
#include "common/metrics.h"
3031
#include "common/types.h"
3132
#include "core/common/global_flags.h"
@@ -41,7 +42,10 @@ namespace xllm {
4142
LLMWorkerImpl::LLMWorkerImpl(const ParallelArgs& parallel_args,
4243
const torch::Device& device,
4344
const runtime::Options& options)
44-
: WorkerImpl(parallel_args, device, options) {}
45+
: WorkerImpl(parallel_args, device, options) {
46+
// initialize flashinfer workspace
47+
FlashinferWorkspace::get_instance().initialize(device_);
48+
}
4549

4650
bool LLMWorkerImpl::init_model(ModelContext& context) {
4751
CHECK(model_ == nullptr) << "Model is already initialized.";

0 commit comments

Comments
 (0)