Skip to content

Commit 2540bd7

Browse files
committed
feat: add layer-wise KV cache H2D copy optimization.
1 parent f97dac3 commit 2540bd7

File tree

16 files changed

+185
-161
lines changed

16 files changed

+185
-161
lines changed

xllm/core/framework/kv_cache/kv_cache_store.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,14 @@ uint32_t KVCacheStore::batch_put(
128128
return block_transfer_info.size();
129129
}
130130

131-
uint64_t success_cnt = str_keys.size();
131+
uint64_t success_cnt = block_transfer_info.size() - str_keys.size();
132132
auto results = client_ptr_->BatchPut(str_keys, slices, rep_config_);
133133

134134
for (int i = 0; i < str_keys.size(); i++) {
135135
if (!results[i].has_value()) {
136-
success_cnt = i;
137-
// LOG(ERROR) << "success_cnt: " << success_cnt
138-
// << ", failed to BatchPut: " << toString(results[i].error());
139136
break;
140137
}
138+
success_cnt++;
141139
}
142140
return success_cnt;
143141
}
@@ -179,15 +177,13 @@ uint32_t KVCacheStore::batch_get(
179177
return 0;
180178
}
181179

182-
uint64_t success_cnt = str_keys.size();
180+
uint64_t success_cnt = 0;
183181
auto results = client_ptr_->BatchGet(str_keys, slices);
184182
for (int i = 0; i < str_keys.size(); i++) {
185183
if (!results[i].has_value()) {
186-
success_cnt = i;
187-
// LOG(ERROR) << "success_cnt: " << success_cnt
188-
// << ", failed to BatchGet: " << toString(results[i].error());
189184
break;
190185
}
186+
success_cnt++;
191187
}
192188
return success_cnt;
193189
}

xllm/core/framework/model/model_input_params.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ struct ModelInputParams {
119119
// Copy graph_buffer to device
120120
params.graph_buffer = safe_to(graph_buffer, device, true);
121121

122+
params.batch_id = batch_id;
123+
122124
return params;
123125
}
124126

@@ -199,6 +201,8 @@ struct ModelInputParams {
199201

200202
#if defined(USE_NPU)
201203
std::shared_ptr<NPULayerSynchronizerImpl> layer_synchronizer = nullptr;
204+
std::shared_ptr<NPULayerSynchronizerImpl> layer_wise_load_synchronizer =
205+
nullptr;
202206
#endif
203207

204208
DpEpPaddingData dp_ep_padding_data;

xllm/core/platform/device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ int Device::synchronize_default_stream() {
9696
#endif
9797
}
9898

99-
std::unique_ptr<Stream> Device::get_stream_from_pool() {
100-
return std::make_unique<Stream>();
99+
std::unique_ptr<Stream> Device::get_stream_from_pool(const int32_t timeout) {
100+
return std::make_unique<Stream>(timeout);
101101
}
102102

103103
} // namespace xllm

xllm/core/platform/device.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class Device {
4444
int64_t free_memory();
4545

4646
int synchronize_default_stream();
47-
std::unique_ptr<Stream> get_stream_from_pool();
47+
std::unique_ptr<Stream> get_stream_from_pool(const int32_t timeout = -1);
4848

4949
private:
5050
struct DeviceMem {

xllm/core/platform/npu/npu_layer_synchronizer.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@ limitations under the License.
1919

2020
namespace xllm {
2121

22-
NPULayerSynchronizerImpl::NPULayerSynchronizerImpl(const int64_t num_layers)
23-
: events_(num_layers, nullptr), event_record_flags_(num_layers) {
22+
NPULayerSynchronizerImpl::NPULayerSynchronizerImpl(const int64_t num_layers,
23+
const int32_t timeout)
24+
: events_(num_layers, nullptr),
25+
event_record_flags_(num_layers),
26+
timeout_(timeout) {
2427
uint32_t flags = ACL_EVENT_SYNC;
2528
for (int64_t i = 0; i < num_layers; ++i) {
2629
auto ret = aclrtCreateEventWithFlag(&events_[i], flags);
@@ -45,9 +48,9 @@ std::atomic<bool>* NPULayerSynchronizerImpl::get_event_flag(
4548

4649
bool NPULayerSynchronizerImpl::synchronize_layer(const int64_t layer_index) {
4750
while (!event_record_flags_[layer_index].load(std::memory_order_acquire));
48-
auto ret = aclrtSynchronizeEvent(events_[layer_index]);
51+
auto ret = aclrtSynchronizeEventWithTimeout(events_[layer_index], timeout_);
4952
if (ret != ACL_SUCCESS) {
50-
LOG(ERROR) << "Synchronize event failed.";
53+
LOG(ERROR) << "Synchronize event failed: " << ret;
5154
return false;
5255
}
5356
return true;

xllm/core/platform/npu/npu_layer_synchronizer.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ namespace xllm {
2424

2525
class NPULayerSynchronizerImpl {
2626
public:
27-
NPULayerSynchronizerImpl(const int64_t num_layers);
27+
NPULayerSynchronizerImpl(const int64_t num_layers,
28+
const int32_t timeout = -1);
2829
virtual ~NPULayerSynchronizerImpl();
2930

3031
aclrtEvent* get_event(const int64_t layer_index);
@@ -34,6 +35,7 @@ class NPULayerSynchronizerImpl {
3435
private:
3536
std::vector<aclrtEvent> events_;
3637
std::vector<std::atomic<bool>> event_record_flags_;
38+
const int32_t timeout_;
3739
};
3840

3941
} // namespace xllm

xllm/core/platform/stream.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@ limitations under the License.
1818
namespace xllm {
1919

2020
#if defined(USE_NPU)
21-
Stream::Stream() : stream_(c10_npu::getNPUStreamFromPool()) {}
21+
Stream::Stream(const int32_t timeout)
22+
: stream_(c10_npu::getNPUStreamFromPool()), timeout_(timeout) {}
2223
#elif defined(USE_MLU)
23-
Stream::Stream() : stream_(torch_mlu::getStreamFromPool()) {}
24+
Stream::Stream(const int32_t timeout)
25+
: stream_(torch_mlu::getStreamFromPool()), timeout_(timeout) {}
2426
#endif
2527

2628
int Stream::synchronize() const {
2729
#if defined(USE_NPU)
28-
return aclrtSynchronizeStream(stream_.stream());
30+
return aclrtSynchronizeStreamWithTimeout(stream_.stream(), timeout_);
2931
#elif defined(USE_MLU)
3032
stream_.unwrap().synchronize();
3133
return 0;

xllm/core/platform/stream.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ namespace xllm {
3434

3535
class Stream {
3636
public:
37-
Stream();
37+
Stream(const int32_t timeout = -1);
3838
~Stream() = default;
3939

4040
Stream(const Stream&) = delete;
@@ -44,13 +44,19 @@ class Stream {
4444

4545
int synchronize() const;
4646
c10::StreamGuard set_stream_guard() const;
47+
#if defined(USE_NPU)
48+
c10_npu::NPUStream* get_stream() { return &stream_; }
49+
#elif defined(USE_MLU)
50+
torch_mlu::MLUStream* get_stream() { return &stream_; }
51+
#endif
4752

4853
private:
4954
#if defined(USE_NPU)
5055
c10_npu::NPUStream stream_;
5156
#elif defined(USE_MLU)
5257
torch_mlu::MLUStream stream_;
5358
#endif
59+
const int32_t timeout_;
5460
};
5561

5662
} // namespace xllm

0 commit comments

Comments (0)