2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -15,7 +15,7 @@ repos:
- id: codespell
args:
- --skip=".git,corelib/hstu/*,third_party/*"
- --ignore-words-list=TE,TBE,tbe,dout,retrival,IndexT
- --ignore-words-list=TE,TBE,tbe,dout,retrival,IndexT,ANS,ans
- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
3 changes: 2 additions & 1 deletion corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py
@@ -1153,7 +1153,8 @@ def load(
meta_json_file = encode_meta_json_file_path(save_dir, table_name)

if isinstance(storage, DynamicEmbeddingTable) and not storage._use_score:
dist.barrier() # sync global timestamp
if dist.is_initialized():
dist.barrier() # sync global timestamp
cast(DynamicEmbeddingTable, storage).update_timestamp()
num_key_files = len(emb_key_files)
for i in range(num_key_files):
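Note on the change above: torch.distributed collectives such as dist.barrier() raise a RuntimeError when the default process group was never initialized (for example, single-process inference or unit tests that skip dist.init_process_group()), so the collective has to be guarded. A minimal sketch of the pattern; the helper name is hypothetical and not part of this PR:

import torch.distributed as dist

def maybe_sync_ranks() -> None:
    # Only issue the collective when a process group exists; otherwise the
    # single-process path proceeds without synchronization.
    if dist.is_initialized():
        dist.barrier()  # sync global timestamp across ranks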
17 changes: 8 additions & 9 deletions docker/Dockerfile
@@ -23,11 +23,7 @@ RUN ARCH=$([ "${TARGETPLATFORM}" = "linux/arm64" ] && echo "aarch64" || echo "x8
if [ ${ARCH} = "aarch64" ]; then \
ln -s /usr/local/cuda-12.9/targets/sbsa-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
else \
if [[ "${INFERENCEBUILD}" == "1" && "${TRITONSERVER_BUILD}" != "1" ]]; then \
ln -s /usr/local/cuda-12.8/targets/${ARCH}-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
else \
ln -s /usr/local/cuda-12.9/targets/${ARCH}-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
fi \
ln -s /usr/local/cuda-12.9/targets/${ARCH}-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
fi
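With the INFERENCEBUILD/TRITONSERVER_BUILD branch removed above, the NVML stub symlink always targets the CUDA 12.9 tree. A hypothetical sanity check, not part of the PR, that the link resolves inside the built image:

import ctypes

# dlopen the stub the Dockerfile linked above; an OSError means the symlink
# at /usr/lib/<arch>-linux-gnu/libnvidia-ml.so.1 is dangling.
try:
    ctypes.CDLL("libnvidia-ml.so.1")
    print("libnvidia-ml.so.1 resolves")
except OSError as exc:
    print(f"libnvidia-ml.so.1 failed to load: {exc}")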

RUN if [ "${INFERENCEBUILD}" != "1" ]; then \
@@ -50,6 +46,7 @@ RUN pip install --no-deps tensordict orjson && \

RUN pip install nvidia-cutlass-dsl==4.3.0


# for dev
RUN apt update -y --fix-missing && \
apt install -y gdb && \
@@ -81,9 +78,11 @@ COPY . .
RUN cd /workspace/recsys-examples/corelib/dynamicemb && \
python setup.py install

RUN if [ "${TRITONSERVER_BUILD}" = "1" ]; then \
pip3 uninstall -y hstu_attn hstu_hopper; \
fi

RUN cd /workspace/deps && rm -rf nvcomp && \
wget https://developer.download.nvidia.com/compute/nvcomp/redist/nvcomp/linux-x86_64/nvcomp-linux-x86_64-5.1.0.21_cuda12-archive.tar.xz && \
tar -xf nvcomp-linux-x86_64-5.1.0.21_cuda12-archive.tar.xz && \
mv nvcomp-linux-x86_64-5.1.0.21_cuda12-archive nvcomp && \
rm nvcomp-linux-x86_64-5.1.0.21_cuda12-archive.tar.xz

RUN cd /workspace/recsys-examples/examples/commons && \
TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0" python3 setup.py install
25 changes: 19 additions & 6 deletions examples/commons/datasets/inference_dataset.py
@@ -160,12 +160,25 @@ def __iter__(self) -> Iterator[HSTUBatch]:
)
dates.append(self._batch_logs_frame.iloc[sample_id][self._date_name])
seq_endptrs.append(seq_endptr)
if len(user_ids) == 0:
continue

last_date = dates[0]
final_user_ids: List[int] = []
final_dates: List[int] = []
final_seq_endptrs: List[int] = []
for uid, date, endp in zip(user_ids, dates, seq_endptrs):
if date != last_date:
continue
if uid not in final_user_ids:
final_user_ids.append(uid)
final_dates.append(date)
final_seq_endptrs.append(endp)
else:
idx = final_user_ids.index(uid)
final_seq_endptrs[idx] = max(final_seq_endptrs[idx], endp)
yield (
torch.tensor(user_ids),
torch.tensor(dates),
torch.tensor(seq_endptrs),
torch.tensor(final_user_ids),
torch.tensor(final_dates),
torch.tensor(final_seq_endptrs),
)

def get_input_batch(
@@ -306,7 +319,7 @@ def get_input_batch(
labels = torch.tensor(labels, dtype=torch.int64, device=self._device)
batch_kwargs = dict(
features=features,
batch_size=self._batch_size,
batch_size=len(user_ids),
feature_to_max_seqlen=feature_to_max_seqlen,
contextual_feature_names=self._contextual_feature_names,
item_feature_name=self._item_feature_name,
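The filtering added to __iter__ keeps only samples that share the first sample's date and collapses duplicate user ids to their furthest sequence end pointer, which is also why batch_size in get_input_batch becomes len(user_ids) instead of the configured maximum. A standalone sketch of the rule with made-up values:

# Hypothetical inputs mirroring the loop above.
user_ids = [7, 7, 3, 5]
dates = [20240101, 20240101, 20240102, 20240101]
seq_endptrs = [10, 25, 8, 4]

last_date = dates[0]
final_user_ids, final_dates, final_seq_endptrs = [], [], []
for uid, date, endp in zip(user_ids, dates, seq_endptrs):
    if date != last_date:
        continue  # drop samples whose date differs from the first sample's
    if uid not in final_user_ids:
        final_user_ids.append(uid)
        final_dates.append(date)
        final_seq_endptrs.append(endp)
    else:
        # duplicate user id: keep the furthest end pointer
        idx = final_user_ids.index(uid)
        final_seq_endptrs[idx] = max(final_seq_endptrs[idx], endp)

assert final_user_ids == [7, 5]
assert final_seq_endptrs == [25, 4]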
224 changes: 103 additions & 121 deletions examples/commons/datasets/random_inference_dataset.py
@@ -12,16 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Dict, List, Optional
from typing import Dict, Iterator, List, Tuple

import torch
from torch.utils.data.dataset import IterableDataset
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

from .hstu_batch import FeatureConfig, HSTUBatch


class RandomInferenceDataGenerator:
class RandomInferenceDataset(
IterableDataset[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]]
):
"""
A random generator for inference batches.

@@ -32,12 +34,12 @@ class RandomInferenceDataGenerator:
action_feature_name (str): The action feature name.
max_num_users (int): The maximum user numbers.
max_batch_size (int): The maximum batch size.
max_seqlen (int): The maximum sequence length (with candidates) for item
in request per user. The length of action sequence in
request the same with that of HISTORY item sequence.
max_history_length (int): The maximum history length for item in request per user.
The length of action sequence in request is the same.
max_num_candidates (int): The maximum candidates number.
max_incremental_seqlen (int): The maximum incremental length of HISTORY
item AND action sequence.
max_num_cached_batches (int, optional): The number of batches to generate. Defaults to 1.
full_mode (bool): The flag for full batch mode.
"""

@@ -49,9 +51,10 @@ def __init__(
action_feature_name: str = "",
max_num_users: int = 1,
max_batch_size: int = 32,
max_seqlen: int = 4096,
max_history_length: int = 4096,
max_num_candidates: int = 200,
max_incremental_seqlen: int = 64,
max_num_cached_batches: int = 1,
full_mode: bool = False,
):
super().__init__()
@@ -72,125 +75,104 @@

self._max_num_users = min(max_num_users, 2**16)
self._max_batch_size = max_batch_size
self._max_hist_len = max_seqlen - max_num_candidates
self._max_incr_fea_len = max(max_incremental_seqlen, 1)
self._max_hist_len = max_history_length
self._max_num_candidates = max_num_candidates
self._max_incr_fea_len = max(max_incremental_seqlen, 1)
self._num_generated_batches = max(max_num_cached_batches, 1)

self._full_mode = full_mode

self._item_history: Dict[int, torch.Tensor] = dict()
self._action_history: Dict[int, torch.Tensor] = dict()

        num_cached_batches = 0
        self._cached_batch = list()
        for seqlen_idx in range(
            max_incremental_seqlen, self._max_hist_len, max_incremental_seqlen
        ):
            for idx in range(0, self._max_num_users, self._max_batch_size):
                if self._full_mode:
                    user_ids = list(
                        range(idx, min(self._max_num_users, idx + self._max_batch_size))
                    )
                else:
                    user_ids = torch.randint(
                        self._max_num_users, (self._max_batch_size,)
                    ).tolist()
                    user_ids = list(set(user_ids))

                batch_size = len(user_ids)

                item_seq = list()
                action_seq = list()
                for uid in user_ids:
                    if uid not in self._item_history or uid not in self._action_history:
                        self._item_history[uid] = torch.randint(
                            self._max_item_id + 1,
                            (self._max_hist_len + self._max_num_candidates,),
                        )
                        self._action_history[uid] = torch.randint(
                            self._max_action_id + 1,
                            (self._max_hist_len + self._max_num_candidates,),
                        )

                    item_seq.append(
                        self._item_history[uid][: seqlen_idx + self._max_num_candidates]
                    )
                    action_seq.append(self._action_history[uid][:seqlen_idx])
                features = KeyedJaggedTensor.from_jt_dict(
                    {
                        self._item_fea_name: JaggedTensor.from_dense(item_seq),
                        self._action_fea_name: JaggedTensor.from_dense(action_seq),
                    }
                )

                if self._full_mode:
                    num_candidates = torch.full((batch_size,), self._max_num_candidates)
                else:
                    num_candidates = torch.randint(
                        low=1, high=self._max_num_candidates + 1, size=(batch_size,)
                    )

                total_history_lengths = torch.full((batch_size,), seqlen_idx * 2)

                batch = HSTUBatch(
                    features=features,
                    batch_size=batch_size,
                    feature_to_max_seqlen=self._fea_name_to_max_seqlen,
                    contextual_feature_names=self._contextual_fea_names,
                    item_feature_name=self._item_fea_name,
                    action_feature_name=self._action_fea_name,
                    max_num_candidates=self._max_num_candidates,
                    num_candidates=num_candidates,
                ).to(device=torch.cuda.current_device())
                self._cached_batch.append(
                    tuple([batch, torch.tensor(user_ids).long(), total_history_lengths])
                )
                num_cached_batches += 1
                if num_cached_batches >= self._num_generated_batches:
                    break

        self._num_generated_batches = len(self._cached_batch)
        self._max_num_batches = self._num_generated_batches
        self._iloc = 0

    def get_inference_batch_user_ids(self) -> Optional[torch.Tensor]:
        if self._full_mode:
            batch_size = self._max_batch_size
            user_ids = list(range(self._max_batch_size))
        else:
            batch_size = random.randint(1, self._max_batch_size)
            user_ids = torch.randint(self._max_num_users, (batch_size,)).tolist()
            user_ids = list(set(user_ids))

        user_ids = torch.tensor(
            [
                uid
                for uid in user_ids
                if uid not in self._item_history
                or len(self._item_history[uid]) < self._max_hist_len
            ]
        ).long()
        if self._full_mode and len(user_ids) == 0:
            batch_size = self._max_batch_size
            user_ids = list(
                range(
                    self._max_batch_size,
                    min(self._max_batch_size * 2, self._max_num_users),
                )
            )
            user_ids = torch.tensor(user_ids).long()
        return user_ids if len(user_ids) > 0 else None

    def get_random_inference_batch(
        self, user_ids, truncate_start_positions
    ) -> Optional[HSTUBatch]:
        batch_size = len(user_ids)
        if batch_size == 0:
            return None
        user_ids = user_ids.tolist()
        item_hists = [
            self._item_history[uid] if uid in self._item_history else torch.tensor([])
            for uid in user_ids
        ]
        action_hists = [
            self._action_history[uid]
            if uid in self._action_history
            else torch.tensor([])
            for uid in user_ids
        ]

        lengths = torch.tensor([len(hist_seq) for hist_seq in item_hists]).long()
        incr_lengths = torch.randint(
            low=1, high=self._max_incr_fea_len + 1, size=(batch_size,)
        )
        new_lengths = torch.clamp(lengths + incr_lengths, max=self._max_hist_len).long()
        incr_lengths = new_lengths - lengths

        num_candidates = torch.randint(
            low=1, high=self._max_num_candidates + 1, size=(batch_size,)
        )
        if self._full_mode:
            incr_lengths = torch.full((batch_size,), self._max_incr_fea_len)
            new_lengths = torch.clamp(
                lengths + incr_lengths, max=self._max_hist_len
            ).long()
            incr_lengths = new_lengths - lengths
            num_candidates = torch.full((batch_size,), self._max_num_candidates)

        # Caveats: truncate_start_positions is for interleaved item-action sequence
        item_start_positions = (truncate_start_positions / 2).to(torch.int32)
        action_start_positions = (truncate_start_positions / 2).to(torch.int32)

        item_seq = list()
        action_seq = list()
        for idx, uid in enumerate(user_ids):
            self._item_history[uid] = torch.cat(
                [
                    item_hists[idx],
                    torch.randint(self._max_item_id + 1, (incr_lengths[idx],)),
                ],
                dim=0,
            ).long()
            self._action_history[uid] = torch.cat(
                [
                    action_hists[idx],
                    torch.randint(self._max_action_id + 1, (incr_lengths[idx],)),
                ],
                dim=0,
            ).long()

            item_history = torch.cat(
                [
                    self._item_history[uid][item_start_positions[idx] :],
                    torch.randint(self._max_item_id + 1, (num_candidates[idx].item(),)),
                ],
                dim=0,
            )
            item_seq.append(item_history)
            action_seq.append(self._action_history[uid][action_start_positions[idx] :])

        features = KeyedJaggedTensor.from_jt_dict(
            {
                self._item_fea_name: JaggedTensor.from_dense(item_seq),
                self._action_fea_name: JaggedTensor.from_dense(action_seq),
            }
        )

        return HSTUBatch(
            features=features,
            batch_size=batch_size,
            feature_to_max_seqlen=self._fea_name_to_max_seqlen,
            contextual_feature_names=self._contextual_fea_names,
            item_feature_name=self._item_fea_name,
            action_feature_name=self._action_fea_name,
            max_num_candidates=self._max_num_candidates,
            num_candidates=num_candidates,
        )

def __iter__(self) -> Iterator[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]]:
"""
Returns an iterator over the cached batches, cycling through them.

Returns:
Tuple[HSTUBatch, torch.Tensor, torch.Tensor]: The next (batch, user_ids, history_lens) in the cycle.
"""
for _ in range(len(self)):
yield self._cached_batch[self._iloc]
self._iloc = (self._iloc + 1) % self._num_generated_batches

def __len__(self) -> int:
"""
Get the number of batches.

Returns:
int: The number of batches.
"""
return self._max_num_batches
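Taken together, the refactor replaces the on-demand RandomInferenceDataGenerator with an IterableDataset that pre-builds all batches in __init__ and then cycles through the cache. A usage sketch; the argument values are made up, feature_configs is assumed to be a List[FeatureConfig] built elsewhere, and a CUDA device is required because cached batches are moved to torch.cuda.current_device():

dataset = RandomInferenceDataset(
    feature_configs=feature_configs,  # assumed to be prepared by the caller
    item_feature_name="item",
    action_feature_name="action",
    max_num_users=64,
    max_batch_size=8,
    max_history_length=256,
    max_num_candidates=16,
    max_incremental_seqlen=32,
    max_num_cached_batches=4,
)
# One pass yields len(dataset) cached batches; the internal cursor wraps,
# so a second pass replays them in the same cyclic order.
for batch, user_ids, history_lens in dataset:
    run_inference(batch, user_ids, history_lens)  # hypothetical consumer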