
Commit fd2c33e

HSTU KV Cache Manager V2 (#251)
* Finalized async kvcache manager implementation * Fix scaling_seqlen for inference * Formatting the code * Fix ops build * Temporary fix of Dockerfile for CI * Fixes for CI and reviews * Fix * Fix tritonserver hstu_model * Fix dist init in dynamicemb inference load * Reduce num_sm computing * Add checks for cuda API calls * Update documents and benchmark results. * Clean up unused initialized data * Fix --------- Co-authored-by: Junyi Qiu <junyiq@nvidia.com>
1 parent 02ce8a2 commit fd2c33e

29 files changed: +3008, -1476 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ repos:
       - id: codespell
         args:
           - --skip=".git,corelib/hstu/*,third_party/*"
-          - --ignore-words-list=TE,TBE,tbe,dout,retrival,IndexT
+          - --ignore-words-list=TE,TBE,tbe,dout,retrival,IndexT,ANS,ans
   - repo: https://github.com/psf/black
     rev: 23.9.1
     hooks:

corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py

Lines changed: 2 additions & 1 deletion
@@ -1153,7 +1153,8 @@ def load(
     meta_json_file = encode_meta_json_file_path(save_dir, table_name)

     if isinstance(storage, DynamicEmbeddingTable) and not storage._use_score:
-        dist.barrier()  # sync global timestamp
+        if dist.is_initialized():
+            dist.barrier()  # sync global timestamp
         cast(DynamicEmbeddingTable, storage).update_timestamp()
     num_key_files = len(emb_key_files)
     for i in range(num_key_files):
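Note: this change guards the collective so the same load path works both under an initialized torch.distributed process group and in single-process inference. A minimal sketch of the pattern outside the diff (the helper name is illustrative, not part of this patch):

import torch.distributed as dist

def barrier_if_initialized() -> None:
    # Only issue the collective when a process group exists; in a
    # single-process inference run dist.is_initialized() is False and an
    # unconditional dist.barrier() would fail with an uninitialized
    # process-group error.
    if dist.is_initialized():
        dist.barrier()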

docker/Dockerfile

Lines changed: 8 additions & 9 deletions
@@ -23,11 +23,7 @@ RUN ARCH=$([ "${TARGETPLATFORM}" = "linux/arm64" ] && echo "aarch64" || echo "x8
     if [ ${ARCH} = "aarch64" ]; then \
         ln -s /usr/local/cuda-12.9/targets/sbsa-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
     else \
-        if [[ "${INFERENCEBUILD}" == "1" && "${TRITONSERVER_BUILD}" != "1" ]]; then \
-            ln -s /usr/local/cuda-12.8/targets/${ARCH}-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
-        else \
-            ln -s /usr/local/cuda-12.9/targets/${ARCH}-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
-        fi \
+        ln -s /usr/local/cuda-12.9/targets/${ARCH}-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
     fi

 RUN if [ "${INFERENCEBUILD}" != "1" ]; then \
@@ -50,6 +46,7 @@ RUN pip install --no-deps tensordict orjson && \

 RUN pip install nvidia-cutlass-dsl==4.3.0

+
 # for dev
 RUN apt update -y --fix-missing && \
     apt install -y gdb && \
@@ -81,9 +78,11 @@ COPY . .
 RUN cd /workspace/recsys-examples/corelib/dynamicemb && \
     python setup.py install

-RUN if [ "${TRITONSERVER_BUILD}" = "1" ]; then \
-    pip3 uninstall -y hstu_attn hstu_hopper; \
-    fi
-
+RUN cd /workspace/deps && rm -rf nvcomp && \
+    wget https://developer.download.nvidia.com/compute/nvcomp/redist/nvcomp/linux-x86_64/nvcomp-linux-x86_64-5.1.0.21_cuda12-archive.tar.xz && \
+    tar -xf nvcomp-linux-x86_64-5.1.0.21_cuda12-archive.tar.xz && \
+    mv nvcomp-linux-x86_64-5.1.0.21_cuda12-archive nvcomp && \
+    rm nvcomp-linux-x86_64-5.1.0.21_cuda12-archive.tar.xz
+
 RUN cd /workspace/recsys-examples/examples/commons && \
     TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0" python3 setup.py install

examples/commons/datasets/inference_dataset.py

Lines changed: 19 additions & 6 deletions
@@ -160,12 +160,25 @@ def __iter__(self) -> Iterator[HSTUBatch]:
                 )
                dates.append(self._batch_logs_frame.iloc[sample_id][self._date_name])
                seq_endptrs.append(seq_endptr)
-            if len(user_ids) == 0:
-                continue
+
+            last_date = dates[0]
+            final_user_ids: List[int] = []
+            final_dates: List[int] = []
+            final_seq_endptrs: List[int] = []
+            for uid, date, endp in zip(user_ids, dates, seq_endptrs):
+                if date != last_date:
+                    continue
+                if uid not in final_user_ids:
+                    final_user_ids.append(uid)
+                    final_dates.append(date)
+                    final_seq_endptrs.append(endp)
+                else:
+                    idx = final_user_ids.index(uid)
+                    final_seq_endptrs[idx] = max(final_seq_endptrs[idx], endp)
             yield (
-                torch.tensor(user_ids),
-                torch.tensor(dates),
-                torch.tensor(seq_endptrs),
+                torch.tensor(final_user_ids),
+                torch.tensor(final_dates),
+                torch.tensor(final_seq_endptrs),
             )

     def get_input_batch(
@@ -306,7 +319,7 @@ def get_input_batch(
         labels = torch.tensor(labels, dtype=torch.int64, device=self._device)
         batch_kwargs = dict(
             features=features,
-            batch_size=self._batch_size,
+            batch_size=len(user_ids),
             feature_to_max_seqlen=feature_to_max_seqlen,
             contextual_feature_names=self._contextual_feature_names,
             item_feature_name=self._item_feature_name,
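Note: the reworked __iter__ above keeps only samples whose date matches the first sample's date and collapses duplicate user ids, retaining the largest sequence end pointer per user. The same filtering as a standalone sketch (the helper name and signature are illustrative, not part of the module):

from typing import List, Tuple

def dedupe_by_first_date(
    user_ids: List[int], dates: List[int], seq_endptrs: List[int]
) -> Tuple[List[int], List[int], List[int]]:
    # Mirror of the filtering added in __iter__: drop samples from other
    # dates, deduplicate user ids, and keep the max end pointer per user.
    if not user_ids:
        return [], [], []
    last_date = dates[0]
    final_user_ids: List[int] = []
    final_dates: List[int] = []
    final_seq_endptrs: List[int] = []
    for uid, date, endp in zip(user_ids, dates, seq_endptrs):
        if date != last_date:
            continue
        if uid not in final_user_ids:
            final_user_ids.append(uid)
            final_dates.append(date)
            final_seq_endptrs.append(endp)
        else:
            idx = final_user_ids.index(uid)
            final_seq_endptrs[idx] = max(final_seq_endptrs[idx], endp)
    return final_user_ids, final_dates, final_seq_endptrs

This also explains the batch_size change in get_input_batch: after deduplication the effective batch size is len(user_ids) rather than the configured self._batch_size.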

examples/commons/datasets/random_inference_dataset.py

Lines changed: 103 additions & 121 deletions
@@ -12,16 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import random
-from typing import Dict, List, Optional
+from typing import Dict, Iterator, List, Tuple

 import torch
+from torch.utils.data.dataset import IterableDataset
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

 from .hstu_batch import FeatureConfig, HSTUBatch


-class RandomInferenceDataGenerator:
+class RandomInferenceDataset(
+    IterableDataset[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]]
+):
     """
     A random generator for the inference batches

@@ -32,12 +34,12 @@ class RandomInferenceDataGenerator:
         action_feature_name (str): The action feature name.
         max_num_users (int): The maximum user numbers.
         max_batch_size (int): The maximum batch size.
-        max_seqlen (int): The maximum sequence length (with candidates) for item
-                          in request per user. The length of action sequence in
-                          request the same with that of HISTORY item sequence.
+        max_history_length (int): The maximum history length for item in request per user.
+                                  The length of action sequence in request is the same.
         max_num_candidates (int): The maximum candidates number.
         max_incremental_seqlen (int): The maximum incremental length of HISTORY
                                       item AND action sequence.
+        max_num_cached_batches (int, optional): The number of batches to generate. Defaults to 1.
         full_mode (bool): The flag for full batch mode.
     """

@@ -49,9 +51,10 @@ def __init__(
         action_feature_name: str = "",
         max_num_users: int = 1,
         max_batch_size: int = 32,
-        max_seqlen: int = 4096,
+        max_history_length: int = 4096,
         max_num_candidates: int = 200,
         max_incremental_seqlen: int = 64,
+        max_num_cached_batches: int = 1,
         full_mode: bool = False,
     ):
         super().__init__()
@@ -72,125 +75,104 @@ def __init__(

         self._max_num_users = min(max_num_users, 2**16)
         self._max_batch_size = max_batch_size
-        self._max_hist_len = max_seqlen - max_num_candidates
-        self._max_incr_fea_len = max(max_incremental_seqlen, 1)
+        self._max_hist_len = max_history_length
         self._max_num_candidates = max_num_candidates
+        self._max_incr_fea_len = max(max_incremental_seqlen, 1)
+        self._num_generated_batches = max(max_num_cached_batches, 1)

         self._full_mode = full_mode

         self._item_history: Dict[int, torch.Tensor] = dict()
         self._action_history: Dict[int, torch.Tensor] = dict()

-    def get_inference_batch_user_ids(self) -> Optional[torch.Tensor]:
-        if self._full_mode:
-            batch_size = self._max_batch_size
-            user_ids = list(range(self._max_batch_size))
-        else:
-            batch_size = random.randint(1, self._max_batch_size)
-            user_ids = torch.randint(self._max_num_users, (batch_size,)).tolist()
-            user_ids = list(set(user_ids))
-
-        user_ids = torch.tensor(
-            [
-                uid
-                for uid in user_ids
-                if uid not in self._item_history
-                or len(self._item_history[uid]) < self._max_hist_len
-            ]
-        ).long()
-        if self._full_mode and len(user_ids) == 0:
-            batch_size = self._max_batch_size
-            user_ids = list(
-                range(
-                    self._max_batch_size,
-                    min(self._max_batch_size * 2, self._max_num_users),
+        num_cached_batches = 0
+        self._cached_batch = list()
+        for seqlen_idx in range(
+            max_incremental_seqlen, self._max_hist_len, max_incremental_seqlen
+        ):
+            for idx in range(0, self._max_num_users, self._max_batch_size):
+                if self._full_mode:
+                    user_ids = list(
+                        range(idx, min(self._max_num_users, idx + self._max_batch_size))
+                    )
+                else:
+                    user_ids = torch.randint(
+                        self._max_num_users, (self._max_batch_size,)
+                    ).tolist()
+                    user_ids = list(set(user_ids))
+
+                batch_size = len(user_ids)
+
+                item_seq = list()
+                action_seq = list()
+                for uid in user_ids:
+                    if uid not in self._item_history or uid not in self._action_history:
+                        self._item_history[uid] = torch.randint(
+                            self._max_item_id + 1,
+                            (self._max_hist_len + self._max_num_candidates,),
+                        )
+                        self._action_history[uid] = torch.randint(
+                            self._max_action_id + 1,
+                            (self._max_hist_len + self._max_num_candidates,),
+                        )
+
+                    item_seq.append(
+                        self._item_history[uid][: seqlen_idx + self._max_num_candidates]
+                    )
+                    action_seq.append(self._action_history[uid][:seqlen_idx])
+                features = KeyedJaggedTensor.from_jt_dict(
+                    {
+                        self._item_fea_name: JaggedTensor.from_dense(item_seq),
+                        self._action_fea_name: JaggedTensor.from_dense(action_seq),
+                    }
+                )
+
+                if self._full_mode:
+                    num_candidates = torch.full((batch_size,), self._max_num_candidates)
+                else:
+                    num_candidates = torch.randint(
+                        low=1, high=self._max_num_candidates + 1, size=(batch_size,)
+                    )
+
+                total_history_lengths = torch.full((batch_size,), seqlen_idx * 2)
+
+                batch = HSTUBatch(
+                    features=features,
+                    batch_size=batch_size,
+                    feature_to_max_seqlen=self._fea_name_to_max_seqlen,
+                    contextual_feature_names=self._contextual_fea_names,
+                    item_feature_name=self._item_fea_name,
+                    action_feature_name=self._action_fea_name,
+                    max_num_candidates=self._max_num_candidates,
+                    num_candidates=num_candidates,
+                ).to(device=torch.cuda.current_device())
+                self._cached_batch.append(
+                    tuple([batch, torch.tensor(user_ids).long(), total_history_lengths])
                 )
-            )
-        user_ids = torch.tensor(user_ids).long()
-        return user_ids if len(user_ids) > 0 else None
-
-    def get_random_inference_batch(
-        self, user_ids, truncate_start_positions
-    ) -> Optional[HSTUBatch]:
-        batch_size = len(user_ids)
-        if batch_size == 0:
-            return None
-        user_ids = user_ids.tolist()
-        item_hists = [
-            self._item_history[uid] if uid in self._item_history else torch.tensor([])
-            for uid in user_ids
-        ]
-        action_hists = [
-            self._action_history[uid]
-            if uid in self._action_history
-            else torch.tensor([])
-            for uid in user_ids
-        ]
-
-        lengths = torch.tensor([len(hist_seq) for hist_seq in item_hists]).long()
-        incr_lengths = torch.randint(
-            low=1, high=self._max_incr_fea_len + 1, size=(batch_size,)
-        )
-        new_lengths = torch.clamp(lengths + incr_lengths, max=self._max_hist_len).long()
-        incr_lengths = new_lengths - lengths
-
-        num_candidates = torch.randint(
-            low=1, high=self._max_num_candidates + 1, size=(batch_size,)
-        )
-        if self._full_mode:
-            incr_lengths = torch.full((batch_size,), self._max_incr_fea_len)
-            new_lengths = torch.clamp(
-                lengths + incr_lengths, max=self._max_hist_len
-            ).long()
-            incr_lengths = new_lengths - lengths
-            num_candidates = torch.full((batch_size,), self._max_num_candidates)
-
-        # Caveats: truncate_start_positions is for interleaved item-action sequence
-        item_start_positions = (truncate_start_positions / 2).to(torch.int32)
-        action_start_positions = (truncate_start_positions / 2).to(torch.int32)
-
-        item_seq = list()
-        action_seq = list()
-        for idx, uid in enumerate(user_ids):
-            self._item_history[uid] = torch.cat(
-                [
-                    item_hists[idx],
-                    torch.randint(self._max_item_id + 1, (incr_lengths[idx],)),
-                ],
-                dim=0,
-            ).long()
-            self._action_history[uid] = torch.cat(
-                [
-                    action_hists[idx],
-                    torch.randint(self._max_action_id + 1, (incr_lengths[idx],)),
-                ],
-                dim=0,
-            ).long()
-
-            item_history = torch.cat(
-                [
-                    self._item_history[uid][item_start_positions[idx] :],
-                    torch.randint(self._max_item_id + 1, (num_candidates[idx].item(),)),
-                ],
-                dim=0,
-            )
-            item_seq.append(item_history)
-            action_seq.append(self._action_history[uid][action_start_positions[idx] :])
-
-        features = KeyedJaggedTensor.from_jt_dict(
-            {
-                self._item_fea_name: JaggedTensor.from_dense(item_seq),
-                self._action_fea_name: JaggedTensor.from_dense(action_seq),
-            }
-        )
-
-        return HSTUBatch(
-            features=features,
-            batch_size=batch_size,
-            feature_to_max_seqlen=self._fea_name_to_max_seqlen,
-            contextual_feature_names=self._contextual_fea_names,
-            item_feature_name=self._item_fea_name,
-            action_feature_name=self._action_fea_name,
-            max_num_candidates=self._max_num_candidates,
-            num_candidates=num_candidates,
-        )
+                num_cached_batches += 1
+                if num_cached_batches >= self._num_generated_batches:
+                    break
+
+        self._num_generated_batches = len(self._cached_batch)
+        self._max_num_batches = self._num_generated_batches
+        self._iloc = 0
+
+    def __iter__(self) -> Iterator[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]]:
+        """
+        Returns an iterator over the cached batches, cycling through them.
+
+        Returns:
+            Tuple[HSTUBatch, torch.Tensor, torch.Tensor]: The next (batch, user_ids, history_lens) in the cycle.
+        """
+        for _ in range(len(self)):
+            yield self._cached_batch[self._iloc]
+            self._iloc = (self._iloc + 1) % self._num_generated_batches
+
+    def __len__(self) -> int:
+        """
+        Get the number of batches.
+
+        Returns:
+            int: The number of batches.
+        """
+        return self._max_num_batches
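Note: the rewritten class pre-generates its batches in __init__ and then behaves like a regular IterableDataset that cycles through them. A hedged usage sketch (the dataset construction is omitted and consume is an illustrative helper, not part of the module):

import torch

def consume(dataset) -> None:
    # `dataset` is assumed to be an already-constructed RandomInferenceDataset.
    # One pass yields len(dataset) pre-built tuples; the internal cursor
    # advances across passes, so iteration cycles through the cache.
    for batch, user_ids, history_lens in dataset:
        # batch: HSTUBatch already placed on the current CUDA device
        # user_ids: int64 tensor of the user ids in this batch
        # history_lens: per-user total (item + action) history length
        assert user_ids.numel() == history_lens.numel()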
