`examples/hstu/inference/README.md` (9 additions, 25 deletions)
@@ -2,9 +2,9 @@
## Key Features
- 1. Cache for KV data
+ 1. Asynchronous Cache Manager for KV data
- We use GPU memory and host storage for the KV data cache, as in `GpuKVCacheManager` and `HostKVStorageManager`. This can help to reduce the recomputation of KV data.
+ We use GPU memory and host storage for the KV data cache, as in `AsyncKVCacheManager`. This can help to reduce the recomputation of KV data. All KV cache related operations are implemented asynchronously, in order to hide their overhead behind the inference computation.
The GPU KV cache is organized as a paged KV-data table and supports adding/appending, lookup, and eviction of KV data. When appending new data to the GPU cache, we evict data from the oldest users according to the LRU policy if there is no empty page. The HSTU attention kernel also accepts KV data from a paged table.
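
For illustration only, the snippet below sketches the bookkeeping described above: a fixed pool of pages, a per-user page list, and LRU eviction of the least recently used user when no free page is left. The class and method names are hypothetical and do not mirror the actual `GpuKVCacheManager`/`AsyncKVCacheManager` code.

```python
from collections import OrderedDict

class PagedKVCacheSketch:
    """Toy page table: fixed page pool, per-user page lists, LRU eviction."""

    def __init__(self, num_pages: int, page_size: int):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))   # indices into the paged K/V table
        self.user_pages = OrderedDict()            # user_id -> [page_id, ...], kept in LRU order
        self.user_cached_len = {}                  # user_id -> number of cached tokens

    def _evict_lru_user(self):
        # Reclaim all pages of the least recently used user.
        old_user, pages = self.user_pages.popitem(last=False)
        self.free_pages.extend(pages)
        self.user_cached_len.pop(old_user, None)

    def append(self, user_id: int, num_new_tokens: int) -> list:
        # Allocate enough pages for the new tokens, evicting the oldest users if needed.
        pages = self.user_pages.setdefault(user_id, [])
        self.user_pages.move_to_end(user_id)       # current user becomes most recently used
        cached = self.user_cached_len.get(user_id, 0)
        total_pages = -(-(cached + num_new_tokens) // self.page_size)   # ceil division
        needed = total_pages - len(pages)
        while needed > len(self.free_pages) and len(self.user_pages) > 1:
            self._evict_lru_user()
        if needed > len(self.free_pages):
            raise RuntimeError("history sequence does not fit into the page pool")
        pages.extend(self.free_pages.pop() for _ in range(needed))
        self.user_cached_len[user_id] = cached + num_new_tokens
        return pages                               # page ids a copy kernel would write K/V into
```

For example, with `page_size=32`, appending 100 new tokens for a fresh user claims 4 pages, evicting the oldest users first if the free pool is short.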
@@ -33,33 +33,16 @@ The dense module is served as one instance per GPU, and the KV cache is not supp
### KVCache Usage
1. KVCache Manager supports the following operations (see the usage sketch after this list):
- * `get_user_kvdata_info`: to get the current cached length and the index of the first cached token in the history sequence
- * `prepare_kv_cache`: to allocate the required cache pages. The input history sequence needs to be
- * `paged_kvcache_ops.append_kvcache`: the CUDA kernel to copy the `K, V` values into the allocated cache pages
- * `offload_kv_cache`: to offload the KV data from GPU KVCache to Host KV storage.
+ * `prepare_kvcache_async`: to trigger, in the background, the allocation of the required KV cache pages, the kvcache_metadata computation, and the onloading of KV data from Host KV storage to GPU KVCache.
+ * `prepare_kvcache_wait`: to wait for the new KV cache page allocation and the kvcache_metadata computation to finish.
+ * `paged_kvcache_ops.append_kvcache`: the CUDA kernel to copy the `K, V` values into the allocated cache pages.
+ * `offload_kvcache`: to trigger offloading of the KV data from GPU KVCache to Host KV storage in the background.
* `evict_kv_cache`: to evict all the KV data in the KVCache Manager.
- 2. Currently, the KVCache manager needs to be accessed from a single thread.
+ 2. Currently, the KVCache manager needs to be accessed from a single inference stream. There is no multi-stream support.
- 3. For different requests, the calls to `get_user_kvdata_info` and `prepare_kv_cache` need to be made in order and cannot be interleaved, since the allocation in `prepare_kv_cache` may evict the cached data of other users, which changes the user kvdata_info.
+ 3. The KVCache manager accepts the full user history sequence as input. The removal of already-cached tokens from the sequence is done within the inference forward pass.
- 4. The KVCache manager does not support a discontinuous user history sequence as input from the same user. The overlapping tokens need to be removed before sending the sequence to the inference model. Doing the overlap removal in the upstream stage should be more performant than doing it in the inference model.
-
- ```
- [current KV data in cache] userId: 0, starting position: 0, cached length: 10
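
As a rough usage sketch of the operations listed above, the snippet below shows one plausible per-request ordering of the asynchronous calls. The function signature, argument names, return values, and the `kv_cache_manager`/`paged_kvcache_ops` handles are assumptions for illustration, not the example's actual API.

```python
def run_inference_step(model, kv_cache_manager, user_ids, history_seqs):
    # 1. Kick off page allocation, kvcache_metadata computation, and host -> GPU
    #    onloading in the background, so they overlap with other host-side work.
    handle = kv_cache_manager.prepare_kvcache_async(user_ids, history_seqs)

    # ... other per-batch preparation (embedding lookup, candidate assembly, ...) ...

    # 2. Wait until the new pages and the kvcache_metadata are ready before use.
    kvcache_metadata = kv_cache_manager.prepare_kvcache_wait(handle)

    # 3. Run the forward pass. Somewhere in this step the
    #    `paged_kvcache_ops.append_kvcache` kernel is expected to copy the newly
    #    computed K, V values into the allocated pages (exact call site is an assumption).
    outputs = model.forward(user_ids, history_seqs, kvcache_metadata)

    # 4. Trigger background offloading of GPU KV data to Host KV storage.
    kv_cache_manager.offload_kvcache(user_ids)
    return outputs
```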
`examples/hstu/inference/benchmark/README.md` (3 additions, 2 deletions)
@@ -36,7 +36,7 @@ Here we benchmarked with a synthetic input dataset:
* Each input request has 256 item candidates for ranking.
* Generate data for 1, 2, 4 and 8 users to benchmark with different batch sizes.
- We can achieve **1.4x ~ 2.7x** performance speedup for inference (with batch size ranging from 1 to 8), after utilizing the KV cache and CUDA graph optimization.
+ We can achieve **1.3x ~ 2.6x** performance speedup for inference (with batch size ranging from 1 to 8), after utilizing the KV cache and CUDA graph optimization.
Performance results:
@@ -46,7 +46,8 @@ Note:
1. The baseline performance is based on our implementation without KVCache support and CUDA Graph optimization.
2. The end-to-end performance includes the embedding part, which utilizes both native `EmbeddingCollection` from TorchRec and `DynamicEmbedding`.
- 3. The number of input sequences from the synthetic dataset increases according to the batch size.
+ 3. The number of input sequences from the synthetic dataset increases according to the batch size. All test cases have 16 batches in total.
+ 4. In the test cases with KVCache enabled, the KV cache preparation and onloading/offloading are included in the time measurement, but they are hidden as asynchronous operations.
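
To make note 4 concrete, here is a generic PyTorch sketch (not the benchmark's actual measurement code) of how work issued on a side CUDA stream still falls inside the measured interval while overlapping with compute on the default stream.

```python
import torch

side_stream = torch.cuda.Stream()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

host_kv = torch.randn(1024, 256, pin_memory=True)        # stand-in for Host KV storage
gpu_kv = torch.empty_like(host_kv, device="cuda")        # stand-in for GPU KV cache pages
x = torch.randn(4096, 4096, device="cuda")

start.record()
with torch.cuda.stream(side_stream):
    gpu_kv.copy_(host_kv, non_blocking=True)             # "onload" runs in the background
y = x @ x                                                 # stand-in for inference compute on the default stream
torch.cuda.current_stream().wait_stream(side_stream)      # join before the onloaded data is used
end.record()
torch.cuda.synchronize()
print(f"measured end-to-end time: {start.elapsed_time(end):.2f} ms")
```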