190 changes: 183 additions & 7 deletions examples/commons/datasets/hstu_batch.py
@@ -12,32 +12,196 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import List, Optional
from enum import Enum
from typing import Dict, List, Optional

import gin
import numpy as np
import torch
from commons.sequence_batch.batch import BaseBatch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class DistType(str, Enum):
"""Supported random distribution types.

Only used for benchmark / random data generation (see :class:`RandomDistribution`).
"""

UNIFORM = "uniform"
NORMAL = "normal"
ZIPF = "zipf"
LOGNORMAL = "lognormal"


@gin.configurable
@dataclass
class RandomDistribution:
"""
A configurable random distribution for generating non-negative integer (natural number) samples.

.. note::
**Benchmark only** — This class is designed for synthetic / random data generation
in benchmarking and testing scenarios. It is **not** used when training with real
datasets (e.g. MovieLens, KuaiRand).

All samples are natural numbers (>= 0). The lower bound defaults to 0; by default there
is no upper bound (``high=None`` means no upper clamp).

Supports four distribution types:
- **uniform**: Samples uniformly from [low, high). ``high`` is required for uniform.
- **normal**: Samples from N(mean, std), then rounds and clamps to [low, high).
- **zipf**: Samples from Zipf(alpha) with P(k) ∝ k^{-alpha} (k >= 1),
shifted to start from ``low``, then clamps to [low, high).
- **lognormal**: Samples from LogNormal(mean, std), then rounds and clamps to
[low, high). Useful for modeling sequence lengths, which are typically
right-skewed (many short sequences, few long ones). For a lognormal, the
probability mass inside a multiplicative interval [M/r, M·r] around the
median M depends only on r (for a fixed shape σ), not on the scale M.

Args:
dist_type: One of ``DistType.UNIFORM``, ``DistType.NORMAL``, ``DistType.ZIPF``,
``DistType.LOGNORMAL``.
low: Inclusive lower bound. Default: 0.
high: Optional exclusive upper bound. Required for uniform; for normal, zipf and
lognormal, samples are clamped to at most ``high - 1``. Default: None (no upper
bound, except uniform which requires it).
mean: Mean parameter. For normal this is the distribution mean; for lognormal
this is the **actual (real-space) mean** E[X], *not* the underlying μ.
Default: None (auto-inferred).
std: Std parameter. For normal this is the distribution std; for lognormal
this is the **actual (real-space) standard deviation** SD[X], *not* the
underlying σ. The underlying lognormal parameters are derived as:
``σ² = ln(1 + (std/mean)²)``, ``μ = ln(mean) - σ²/2``.
Default: None. For lognormal, defaults to ``mean / 2`` (CV = 0.5,
~80% of samples within [0.49M, 1.64M]).
alpha: Shape parameter for Zipf distribution (must be > 1.0). Default: 1.5.

Example:
>>> # Zipf with no upper limit
>>> dist = RandomDistribution(DistType.ZIPF, alpha=1.2)
>>> samples = dist.sample(size=128, device=torch.device("cpu"))
>>> # Normal clamped to [10, 500]
>>> dist = RandomDistribution(DistType.NORMAL, low=10, high=500, mean=100, std=50)
>>> # Lognormal: actual mean=512, actual std=256 (~80% within [250, 840])
>>> dist = RandomDistribution(DistType.LOGNORMAL, mean=512, std=256, high=2048)
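>>> # Illustrative sanity check of the real-space parameterization: with
>>> # mean=512 and std=256, sigma^2 = ln(1 + 0.5**2) ~= 0.223 and
>>> # mu = ln(512) - sigma^2 / 2 ~= 6.13, so large samples average near 512.
>>> samples = dist.sample(size=100_000, device=torch.device("cpu"))
>>> assert 450 < samples.float().mean() < 580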
"""

dist_type: DistType
low: int = 0
high: Optional[int] = None
# normal distribution parameters
mean: Optional[float] = None
std: Optional[float] = None
# zipf distribution parameter
alpha: Optional[float] = None

def sample(
self,
size: int,
device: torch.device,
) -> torch.Tensor:
"""
Generate ``size`` non-negative integer samples from the configured distribution.

Args:
size: Number of samples to generate.
device: Target device for the returned tensor.

Returns:
A 1-D ``torch.Tensor`` of shape ``(size,)`` with dtype ``torch.long``.
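
Example:
>>> # Illustrative: samples come back as a 1-D long tensor in [low, high).
>>> out = RandomDistribution(DistType.UNIFORM, low=0, high=10).sample(
...     size=4, device=torch.device("cpu")
... )
>>> out.shape, out.dtype
(torch.Size([4]), torch.int64)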
"""
lo = self.low
hi = self.high # None means no upper bound

if self.dist_type == DistType.UNIFORM:
assert hi is not None, "uniform distribution requires `high` to be set"
assert hi > lo, f"uniform requires high > low, got [{lo}, {hi})"
return torch.randint(lo, hi, (size,), device=device)

elif self.dist_type == DistType.NORMAL:
assert (
self.mean is not None and self.std is not None
), "normal distribution requires `mean` and `std` to be set"
assert self.std > 0, f"normal requires std > 0, got {self.std}"
samples = torch.normal(self.mean, self.std, (size,))
samples = samples.clamp(min=lo)
if hi is not None:
samples = samples.clamp(max=hi - 1)
return samples.round().long().to(device)

elif self.dist_type == DistType.LOGNORMAL:
assert (
self.mean is not None
), "lognormal distribution requires `mean` to be set"
assert self.mean > 0, f"lognormal requires mean > 0, got {self.mean}"
# Default CV = 0.5 (std = mean / 2) when std is not specified
actual_std = self.std if self.std is not None else self.mean / 2.0
assert actual_std > 0, f"lognormal requires std > 0, got {actual_std}"
# User specifies actual (real-space) mean M and std S.
# Convert to underlying normal parameters:
# σ² = ln(1 + (S/M)²)
# μ = ln(M) - σ²/2
cv_sq = (actual_std / self.mean) ** 2
sigma_sq = math.log(1.0 + cv_sq)
mu = math.log(self.mean) - sigma_sq / 2.0
sigma = math.sqrt(sigma_sq)
samples = torch.distributions.LogNormal(mu, sigma).sample((size,))
samples = samples.clamp(min=lo)
if hi is not None:
samples = samples.clamp(max=hi - 1)
return samples.round().long().to(device)

elif self.dist_type == DistType.ZIPF:
alpha = self.alpha if self.alpha is not None else 1.5
assert alpha > 1.0, f"zipf requires alpha > 1.0, got {alpha}"
# numpy zipf generates integers >= 1, shift by (lo - 1) so minimum is lo
raw = np.random.zipf(alpha, size=size)
samples = torch.from_numpy(raw).long() + (lo - 1)
samples = samples.clamp(min=lo)
if hi is not None:
samples = samples.clamp(max=hi - 1)
return samples.to(device)

else:
raise ValueError(f"Unknown distribution type: {self.dist_type}")


@dataclass
class FeatureConfig:
"""
Configuration for features in a dataset. A FeatureConfig is a collection of features that share the same seqlen (also the same max_sequence_length).
For example, in HSTU-based models, an item is always associated with a timestamp token.
Configuration for features in a dataset.

.. note::
**Benchmark / test only** — ``FeatureConfig`` is consumed by
:meth:`HSTUBatch.random` and :class:`HSTURandomDataset` to generate synthetic
data. It is **not** used when training with real datasets (e.g. MovieLens,
KuaiRand). In the gin configuration layer the corresponding entry point is
:class:`FeatureArgs` (inside ``BenchmarkDatasetArgs``).

A ``FeatureConfig`` groups features that share the same sequence length (and the
same ``max_sequence_length``). For example, in HSTU-based models an item feature
is always paired with a timestamp token — both share one ``FeatureConfig``.

Attributes:
feature_names (List[str]): List of names for the features.
max_item_ids (List[int]): List of maximum item IDs for each feature.
max_sequence_length (int): The maximum length of sequences in the dataset.
is_jagged (bool): Whether the sequences are jagged (i.e., have varying lengths).
seqlen_dist (Optional[RandomDistribution]): Distribution for generating sequence lengths.
Only used when ``is_jagged=True``. If None, defaults to uniform [0, max_sequence_length).
value_dists (Optional[Dict[str, RandomDistribution]]): Per-feature distributions for
generating values, keyed by feature name. Features not present in the dict fall back
to uniform [0, max_item_id). If None, all features use the default uniform distribution.
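
Example:
>>> # Illustrative config: an item/timestamp pair with lognormal lengths and
>>> # zipf-distributed item ids (all parameter values here are made up).
>>> fc = FeatureConfig(
...     feature_names=["item_id", "timestamp"],
...     max_item_ids=[100_000, 10_000],
...     max_sequence_length=2048,
...     is_jagged=True,
...     seqlen_dist=RandomDistribution(DistType.LOGNORMAL, mean=512, std=256, high=2048),
...     value_dists={"item_id": RandomDistribution(DistType.ZIPF, alpha=1.2, high=100_000)},
... )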
"""

feature_names: List[str]
max_item_ids: List[int]
max_sequence_length: int
is_jagged: bool
seqlen_dist: Optional[RandomDistribution] = None
value_dists: Optional[Dict[str, RandomDistribution]] = None


@dataclass
@@ -132,9 +296,15 @@ def random(
for fc in feature_configs:
# Generate data for actual_batch_size samples
if fc.is_jagged:
seqlen = torch.randint(
fc.max_sequence_length, (actual_batch_size,), device=device
)
if fc.seqlen_dist is not None:
seqlen = fc.seqlen_dist.sample(
size=actual_batch_size,
device=device,
)
else:
seqlen = torch.randint(
fc.max_sequence_length, (actual_batch_size,), device=device
)
else:
seqlen = torch.full(
(actual_batch_size,), fc.max_sequence_length, device=device
@@ -154,7 +324,13 @@
for feature_name, max_item_id in zip(fc.feature_names, fc.max_item_ids):
if feature_name in contextual_feature_names and fc.is_jagged:
warnings.warn(f"contextual feature {feature_name} is jagged")
value = torch.randint(max_item_id, (cur_seqlen_sum,), device=device)
if fc.value_dists is not None and feature_name in fc.value_dists:
value = fc.value_dists[feature_name].sample(
size=cur_seqlen_sum,
device=device,
)
else:
value = torch.randint(max_item_id, (cur_seqlen_sum,), device=device)
keys.append(feature_name)
values.append(value)
lengths.append(seqlen)
9 changes: 8 additions & 1 deletion examples/commons/datasets/hstu_random_dataset.py
@@ -24,7 +24,14 @@

class HSTURandomDataset(IterableDataset[HSTUBatch]):
"""
A dummy sequence dataset for generating batches of data.
A synthetic (random) dataset for benchmark and testing purposes.

.. note::
**Benchmark / test only** — This dataset generates random batches via
:meth:`HSTUBatch.random` and is **not** used when training with real
datasets (e.g. MovieLens, KuaiRand). It is instantiated automatically
when :class:`~utils.gin_config_args.BenchmarkDatasetArgs` is provided
as the dataset configuration.

Args:
batch_size (int): The batch size per rank.
41 changes: 33 additions & 8 deletions examples/commons/distributed/batch_allgather.py
@@ -1,10 +1,11 @@
from typing import Union
from dataclasses import fields
from typing import Dict, List, Union

import torch
from commons.ops.collective_ops import (
gather_along_first_dim,
gatherv_along_first_dim,
keyed_jagged_tensor_allgather,
keyed_jagged_tensor_list_allgather,
)
from commons.sequence_batch.batch import BaseBatch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -17,6 +18,10 @@ def allgather_batch(
"""
Allgather the batch across the process group.
Note that ``feature_to_max_seqlen`` in the batch is updated as well.

All KJT fields are fused into a **single** pair of AllGather calls
(one for lengths, one for values) via :func:`keyed_jagged_tensor_list_allgather`.
Dense tensor fields are gathered separately.
"""

# TODO@junzhang, we should avoid coping with jagged padding...
@@ -31,20 +36,40 @@
world_size = torch.distributed.get_world_size(pg_group)
global_batch_size = batch.batch_size * world_size

def allgather_tensor_or_kjt(tensor_or_kjt: Union[torch.Tensor, KeyedJaggedTensor]):
if isinstance(tensor_or_kjt, torch.Tensor):
# ---- Phase 1: collect KJT fields and fused AllGather them ----
kjt_field_names: List[str] = []
kjt_inputs: List[KeyedJaggedTensor] = []
for f in fields(batch):
val = getattr(batch, f.name)
if isinstance(val, KeyedJaggedTensor):
kjt_field_names.append(f.name)
kjt_inputs.append(val)

kjt_outputs = keyed_jagged_tensor_list_allgather(kjt_inputs, pg_group)
kjt_result_map: Dict[str, KeyedJaggedTensor] = dict(
zip(kjt_field_names, kjt_outputs)
)

# ---- Phase 2: gather dense tensors one by one ----
def allgather_field(tensor_or_kjt: Union[torch.Tensor, KeyedJaggedTensor]):
if isinstance(tensor_or_kjt, KeyedJaggedTensor):
# Already gathered in Phase 1 – will be patched in below.
return tensor_or_kjt # placeholder, replaced after
elif isinstance(tensor_or_kjt, torch.Tensor):
if actual_batch_size != global_batch_size:
ag_object = gatherv_along_first_dim(tensor_or_kjt, pg_group)
else:
ag_object = gather_along_first_dim(tensor_or_kjt, pg_group)
return ag_object
elif isinstance(tensor_or_kjt, KeyedJaggedTensor):
kjt_out = keyed_jagged_tensor_allgather(tensor_or_kjt, pg_group)
return kjt_out
else:
raise ValueError(f"Unsupported type: {type(tensor_or_kjt)}")

new_batch = batch._apply_to_tensors_or_kjt(allgather_tensor_or_kjt, inplace=False)
new_batch = batch._apply_to_tensors_or_kjt(allgather_field, inplace=False)

# Patch the KJT fields with the fused results
for name, kjt_out in kjt_result_map.items():
setattr(new_batch, name, kjt_out)

new_batch.batch_size = new_batch.batch_size * world_size
# this will block the host until all processes have finished the allreduce.
new_batch.actual_batch_size = actual_batch_size.item()
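
As a point of reference, below is a minimal sketch of the fusion idea assumed above (an illustration of the technique, not the actual implementation of keyed_jagged_tensor_list_allgather): the per-KJT length tensors are concatenated so that one collective moves all of them, then the gathered result is split back per field. The values half works the same way but needs a variable-size gather, which is omitted here.

from typing import List

import torch
import torch.distributed as dist


def fused_lengths_allgather(
    kjt_lengths: List[torch.Tensor], pg: dist.ProcessGroup
) -> List[torch.Tensor]:
    # Concatenate every KJT's lengths so a single all_gather moves them all.
    sizes = [t.numel() for t in kjt_lengths]
    fused = torch.cat(kjt_lengths)
    world_size = dist.get_world_size(pg)
    out = torch.empty(
        world_size * fused.numel(), dtype=fused.dtype, device=fused.device
    )
    dist.all_gather_into_tensor(out, fused, group=pg)
    # Each rank's slab holds its concatenated lengths; split the slab back into
    # per-KJT pieces, then stitch field i's pieces across ranks together.
    per_rank = out.chunk(world_size)
    return [
        torch.cat([slab.split(sizes)[i] for slab in per_rank])
        for i in range(len(sizes))
    ]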