
Commit ab48756

Finalized async kvcache manager implementation

1 parent: 30e0cca

22 files changed: +2750 -723 lines

examples/commons/datasets/inference_dataset.py (19 additions, 6 deletions)
@@ -160,12 +160,25 @@ def __iter__(self) -> Iterator[HSTUBatch]:
                 )
                 dates.append(self._batch_logs_frame.iloc[sample_id][self._date_name])
                 seq_endptrs.append(seq_endptr)
-            if len(user_ids) == 0:
-                continue
+
+            last_date = dates[0]
+            final_user_ids: List[int] = []
+            final_dates: List[int] = []
+            final_seq_endptrs: List[int] = []
+            for (uid, date, endp) in zip(user_ids, dates, seq_endptrs):
+                if date != last_date:
+                    continue
+                if uid not in final_user_ids:
+                    final_user_ids.append(uid)
+                    final_dates.append(date)
+                    final_seq_endptrs.append(endp)
+                else:
+                    idx = final_user_ids.index(uid)
+                    final_seq_endptrs[idx] = max(final_seq_endptrs[idx], endp)
             yield (
-                torch.tensor(user_ids),
-                torch.tensor(dates),
-                torch.tensor(seq_endptrs),
+                torch.tensor(final_user_ids),
+                torch.tensor(final_dates),
+                torch.tensor(final_seq_endptrs),
             )
 
     def get_input_batch(
@@ -306,7 +319,7 @@ def get_input_batch(
         labels = torch.tensor(labels, dtype=torch.int64, device=self._device)
         batch_kwargs = dict(
             features=features,
-            batch_size=self._batch_size,
+            batch_size=len(user_ids), # self._batch_size,
             feature_to_max_seqlen=feature_to_max_seqlen,
             contextual_feature_names=self._contextual_feature_names,
             item_feature_name=self._item_feature_name,
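
For reference, the filtering added to `__iter__` above keeps only samples that share the batch's first date and deduplicates user IDs, retaining the largest `seq_endptr` per user. A minimal standalone sketch of that logic (the helper name `dedup_by_user` and the sample values are illustrative, not part of the repository):

```python
from typing import List, Tuple


def dedup_by_user(
    user_ids: List[int], dates: List[int], seq_endptrs: List[int]
) -> Tuple[List[int], List[int], List[int]]:
    """Keep samples matching the first date; merge duplicate users by max endptr."""
    last_date = dates[0]
    final_user_ids: List[int] = []
    final_dates: List[int] = []
    final_seq_endptrs: List[int] = []
    for uid, date, endp in zip(user_ids, dates, seq_endptrs):
        if date != last_date:
            continue  # drop samples from other dates
        if uid not in final_user_ids:
            final_user_ids.append(uid)
            final_dates.append(date)
            final_seq_endptrs.append(endp)
        else:
            # duplicate user: keep the longer sequence end pointer
            idx = final_user_ids.index(uid)
            final_seq_endptrs[idx] = max(final_seq_endptrs[idx], endp)
    return final_user_ids, final_dates, final_seq_endptrs


# User 7 appears twice (the larger endptr wins); the sample from a different
# date (20240102) is dropped.
print(dedup_by_user([7, 8, 7, 9], [20240101, 20240101, 20240101, 20240102], [10, 5, 30, 8]))
# ([7, 8], [20240101, 20240101], [30, 5])
```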

examples/commons/datasets/random_inference_dataset.py (85 additions, 120 deletions)
@@ -13,15 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
-from typing import Dict, List, Optional
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import torch
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
+from torch.utils.data.dataset import IterableDataset
 
 from .hstu_batch import FeatureConfig, HSTUBatch
 
 
-class RandomInferenceDataGenerator:
+class RandomInferenceDataset(IterableDataset[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]]):
     """
     A random generator for the inference batches
 
@@ -32,12 +33,12 @@ class RandomInferenceDataGenerator:
         action_feature_name (str): The action feature name.
         max_num_users (int): The maximum user numbers.
         max_batch_size (int): The maximum batch size.
-        max_seqlen (int): The maximum sequence length (with candidates) for item
-                          in request per user. The length of action sequence in
-                          request the same with that of HISTORY item sequence.
+        max_history_length (int): The maximum history length for item in request per user.
+                                  The length of action sequence in request is the same.
         max_num_candidates (int): The maximum candidates number.
         max_incremental_seqlen (int): The maximum incremental length of HISTORY
                                       item AND action sequence.
+        max_num_cached_batches (int, optional): The number of batches to generate. Defaults to 1.
         full_mode (bool): The flag for full batch mode.
     """

@@ -49,9 +50,10 @@ def __init__(
         action_feature_name: str = "",
         max_num_users: int = 1,
         max_batch_size: int = 32,
-        max_seqlen: int = 4096,
+        max_history_length: int = 4096,
         max_num_candidates: int = 200,
         max_incremental_seqlen: int = 64,
+        max_num_cached_batches: int = 1,
         full_mode: bool = False,
     ):
         super().__init__()
@@ -72,125 +74,88 @@ def __init__(
 
         self._max_num_users = min(max_num_users, 2**16)
         self._max_batch_size = max_batch_size
-        self._max_hist_len = max_seqlen - max_num_candidates
-        self._max_incr_fea_len = max(max_incremental_seqlen, 1)
+        self._max_hist_len = max_history_length
         self._max_num_candidates = max_num_candidates
+        self._max_incr_fea_len = max(max_incremental_seqlen, 1)
+        self._num_generated_batches = max(max_num_cached_batches, 1)
 
         self._full_mode = full_mode
 
         self._item_history: Dict[int, torch.Tensor] = dict()
         self._action_history: Dict[int, torch.Tensor] = dict()
 
-    def get_inference_batch_user_ids(self) -> Optional[torch.Tensor]:
-        if self._full_mode:
-            batch_size = self._max_batch_size
-            user_ids = list(range(self._max_batch_size))
-        else:
-            batch_size = random.randint(1, self._max_batch_size)
-            user_ids = torch.randint(self._max_num_users, (batch_size,)).tolist()
-            user_ids = list(set(user_ids))
-
-        user_ids = torch.tensor(
-            [
-                uid
-                for uid in user_ids
-                if uid not in self._item_history
-                or len(self._item_history[uid]) < self._max_hist_len
-            ]
-        ).long()
-        if self._full_mode and len(user_ids) == 0:
-            batch_size = self._max_batch_size
-            user_ids = list(
-                range(
-                    self._max_batch_size,
-                    min(self._max_batch_size * 2, self._max_num_users),
+        num_cached_batches = 0
+        self._cached_batch = list()
+        for seqlen_idx in range(max_incremental_seqlen, self._max_hist_len, max_incremental_seqlen):
+            for idx in range(0, self._max_num_users, self._max_batch_size):
+                if self._full_mode:
+                    user_ids = list(range(idx, min(self._max_num_users, idx + self._max_batch_size)))
+                else:
+                    user_ids = torch.randint(self._max_num_users, (batch_size,)).tolist()
+                    user_ids = list(set(user_ids))
+
+                batch_size = len(user_ids)
+
+                item_seq = list()
+                action_seq = list()
+                for uid in user_ids:
+                    if uid not in self._item_history or uid not in self._action_history:
+                        self._item_history[uid] = torch.randint(self._max_item_id + 1, (self._max_hist_len + self._max_num_candidates,))
+                        self._action_history[uid] = torch.randint(self._max_action_id + 1, (self._max_hist_len + self._max_num_candidates,))
+
+                    item_seq.append(self._item_history[uid][:seqlen_idx + self._max_num_candidates])
+                    action_seq.append(self._action_history[uid][:seqlen_idx])
+                features = KeyedJaggedTensor.from_jt_dict(
+                    {
+                        self._item_fea_name: JaggedTensor.from_dense(item_seq),
+                        self._action_fea_name: JaggedTensor.from_dense(action_seq),
+                    }
                 )
-            )
-        user_ids = torch.tensor(user_ids).long()
-        return user_ids if len(user_ids) > 0 else None
-
-    def get_random_inference_batch(
-        self, user_ids, truncate_start_positions
-    ) -> Optional[HSTUBatch]:
-        batch_size = len(user_ids)
-        if batch_size == 0:
-            return None
-        user_ids = user_ids.tolist()
-        item_hists = [
-            self._item_history[uid] if uid in self._item_history else torch.tensor([])
-            for uid in user_ids
-        ]
-        action_hists = [
-            self._action_history[uid]
-            if uid in self._action_history
-            else torch.tensor([])
-            for uid in user_ids
-        ]
-
-        lengths = torch.tensor([len(hist_seq) for hist_seq in item_hists]).long()
-        incr_lengths = torch.randint(
-            low=1, high=self._max_incr_fea_len + 1, size=(batch_size,)
-        )
-        new_lengths = torch.clamp(lengths + incr_lengths, max=self._max_hist_len).long()
-        incr_lengths = new_lengths - lengths
-
-        num_candidates = torch.randint(
-            low=1, high=self._max_num_candidates + 1, size=(batch_size,)
-        )
-        if self._full_mode:
-            incr_lengths = torch.full((batch_size,), self._max_incr_fea_len)
-            new_lengths = torch.clamp(
-                lengths + incr_lengths, max=self._max_hist_len
-            ).long()
-            incr_lengths = new_lengths - lengths
-            num_candidates = torch.full((batch_size,), self._max_num_candidates)
-
-        # Caveats: truncate_start_positions is for interleaved item-action sequence
-        item_start_positions = (truncate_start_positions / 2).to(torch.int32)
-        action_start_positions = (truncate_start_positions / 2).to(torch.int32)
-
-        item_seq = list()
-        action_seq = list()
-        for idx, uid in enumerate(user_ids):
-            self._item_history[uid] = torch.cat(
-                [
-                    item_hists[idx],
-                    torch.randint(self._max_item_id + 1, (incr_lengths[idx],)),
-                ],
-                dim=0,
-            ).long()
-            self._action_history[uid] = torch.cat(
-                [
-                    action_hists[idx],
-                    torch.randint(self._max_action_id + 1, (incr_lengths[idx],)),
-                ],
-                dim=0,
-            ).long()
-
-            item_history = torch.cat(
-                [
-                    self._item_history[uid][item_start_positions[idx] :],
-                    torch.randint(self._max_item_id + 1, (num_candidates[idx].item(),)),
-                ],
-                dim=0,
-            )
-            item_seq.append(item_history)
-            action_seq.append(self._action_history[uid][action_start_positions[idx] :])
-
-        features = KeyedJaggedTensor.from_jt_dict(
-            {
-                self._item_fea_name: JaggedTensor.from_dense(item_seq),
-                self._action_fea_name: JaggedTensor.from_dense(action_seq),
-            }
-        )
-
-        return HSTUBatch(
-            features=features,
-            batch_size=batch_size,
-            feature_to_max_seqlen=self._fea_name_to_max_seqlen,
-            contextual_feature_names=self._contextual_fea_names,
-            item_feature_name=self._item_fea_name,
-            action_feature_name=self._action_fea_name,
-            max_num_candidates=self._max_num_candidates,
-            num_candidates=num_candidates,
-        )
+
+                if self._full_mode:
+                    num_candidates = torch.full((batch_size,), self._max_num_candidates)
+                else:
+                    num_candidates = torch.randint(
+                        low=1, high=self._max_num_candidates + 1, size=(batch_size,)
+                    )
+
+                total_history_lengths = torch.full((batch_size,), seqlen_idx * 2)
+
+                batch = HSTUBatch(
+                    features=features,
+                    batch_size=batch_size,
+                    feature_to_max_seqlen=self._fea_name_to_max_seqlen,
+                    contextual_feature_names=self._contextual_fea_names,
+                    item_feature_name=self._item_fea_name,
+                    action_feature_name=self._action_fea_name,
+                    max_num_candidates=self._max_num_candidates,
+                    num_candidates=num_candidates,
+                ).to(device=torch.cuda.current_device())
+                self._cached_batch.append(tuple([batch, torch.tensor(user_ids).long(), total_history_lengths]))
+                num_cached_batches += 1
+                if num_cached_batches >= self._num_generated_batches:
+                    break
+
+        self._num_generated_batches = len(self._cached_batch)
+        self._max_num_batches = self._num_generated_batches
+        self._iloc = 0
+
+    def __iter__(self) -> Iterator[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]]:
+        """
+        Returns an iterator over the cached batches, cycling through them.
+
+        Returns:
+            Tuple[HSTUBatch, torch.Tensor, torch.Tensor]: The next (batch, user_ids, history_lens) in the cycle.
+        """
+        for _ in range(len(self)):
+            yield self._cached_batch[self._iloc]
+            self._iloc = (self._iloc + 1) % self._num_generated_batches
+
+    def __len__(self) -> int:
+        """
+        Get the number of batches.
+
+        Returns:
+            int: The number of batches.
+        """
+        return self._max_num_batches
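
The rewrite above replaces the on-demand `RandomInferenceDataGenerator` with a `RandomInferenceDataset` that precomputes a pool of batches in `__init__` and then cycles through them from `__iter__`. A minimal sketch of that precompute-and-cycle `IterableDataset` pattern, using a toy tensor payload in place of `HSTUBatch` (the class and variable names here are illustrative, not part of the repository):

```python
from typing import Iterator, List, Tuple

import torch
from torch.utils.data.dataset import IterableDataset


class CachedBatchDataset(IterableDataset[Tuple[torch.Tensor, torch.Tensor]]):
    """Precompute a fixed pool of random batches once, then cycle through them."""

    def __init__(self, num_batches: int = 4, batch_size: int = 2) -> None:
        super().__init__()
        # Build every batch up front, mirroring the __init__-time caching above.
        self._cached_batch: List[Tuple[torch.Tensor, torch.Tensor]] = [
            (torch.randint(100, (batch_size,)), torch.full((batch_size,), i))
            for i in range(num_batches)
        ]
        self._iloc = 0  # cursor pointing at the next cached batch to serve

    def __len__(self) -> int:
        return len(self._cached_batch)

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        # One pass yields len(self) batches; the cursor wraps around so
        # repeated passes keep cycling through the same cached pool.
        for _ in range(len(self)):
            yield self._cached_batch[self._iloc]
            self._iloc = (self._iloc + 1) % len(self._cached_batch)


dataset = CachedBatchDataset()
for user_ids, history_lens in dataset:
    print(user_ids.tolist(), history_lens.tolist())
```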

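
The cached batches pack variable-length per-user sequences with torchrec jagged tensors: `JaggedTensor.from_dense` takes a list of 1-D tensors and `KeyedJaggedTensor.from_jt_dict` groups them under feature names. A small sketch of that construction, with made-up feature keys and values (assumes torchrec is installed):

```python
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

# Two users with item/action histories of different lengths.
item_seq = [torch.tensor([11, 12, 13]), torch.tensor([21, 22])]
action_seq = [torch.tensor([1, 0, 1]), torch.tensor([0, 1])]

features = KeyedJaggedTensor.from_jt_dict(
    {
        "item": JaggedTensor.from_dense(item_seq),
        "action": JaggedTensor.from_dense(action_seq),
    }
)
print(features.keys())             # ['item', 'action']
print(features["item"].lengths())  # tensor([3, 2]) -- one length per user
```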