Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- General purpose patching API for patch-based diffusion
- New positional embedding selection strategy for CorrDiff SongUNet models
- Added Multi-Storage Client to allow checkpointing to/from Object Storage
- Added `InfiniteHashSampler`, a memory-efficient sampler for very large datasets

### Changed

Expand Down
101 changes: 101 additions & 0 deletions physicsnemo/utils/generative/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,107 @@ def __iter__(self) -> Iterator[int]:
idx += 1


class InfiniteHashSampler(torch.utils.data.Sampler[int]): # pragma: no cover
"""
Memory-efficient infinite sampler for very large datasets that uses hash-based
randomization without storing full index arrays.

This sampler is designed for billion-scale datasets where storing np.arange(len(dataset))
would consume too much memory. It provides:
- O(1) memory usage regardless of dataset size
- Deterministic hash-based randomization (when enabled)
- Full support for DistributedDataParallel
- Sequential or pseudo-random access patterns

Parameters
----------
dataset : torch.utils.data.Dataset
The dataset to sample from
rank : int, default=0
The rank of the current process within num_replicas processes
num_replicas : int, default=1
The number of processes participating in distributed sampling
randomize : bool, default=True
Whether to use hash-based randomization of indices
seed : int, default=0
Random seed for deterministic hash-based randomization
"""

def __init__(
self,
dataset: torch.utils.data.Dataset,
rank: int = 0,
num_replicas: int = 1,
randomize: bool = True,
seed: int = 0,
):
if not len(dataset) > 0:
raise ValueError("Dataset must contain at least one item")
if not num_replicas > 0:
raise ValueError("num_replicas must be positive")
if not 0 <= rank < num_replicas:
raise ValueError("rank must be non-negative and less than num_replicas")
super().__init__()
self.dataset = dataset
self.rank = rank
self.num_replicas = num_replicas
self.randomize = randomize
self.seed = seed
self.dataset_size = len(dataset)

def _hash_index(self, sequential_idx: int) -> int:
"""
Apply hash-based randomization to a sequential index.

Uses a simple but effective hash function that mixes the sequential index
with the seed and rank to produce well-distributed pseudo-random indices.
The rank is incorporated to ensure different ranks get different sequences.

Parameters
----------
sequential_idx : int
Sequential index to randomize

Returns
-------
int
Randomized index in range [0, dataset_size)
"""
# Mix index with seed and rank using prime multiplication and bit operations
h = (
sequential_idx + self.seed + self.rank
) * 2654435761 # Large prime multiplier
h ^= h >> 16 # Mix higher bits into lower bits for better distribution
return h % self.dataset_size

def __iter__(self) -> Iterator[int]:
"""
Generate infinite sequence of dataset indices.

For distributed training, each rank gets every num_replicas-th element
from the sequence. The sequence can be either sequential or hash-randomized.

Yields
------
int
Dataset index for this rank
"""
global_idx = self.rank # Start at rank offset for proper DDP distribution

while True:
# Each rank gets different sequential indices: 0,2,4... vs 1,3,5...
sequential_idx = global_idx % self.dataset_size

if self.randomize:
# Apply hash-based randomization
yield self._hash_index(sequential_idx)
else:
# Return sequential index
yield sequential_idx

global_idx += self.num_replicas


# ----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

Expand Down
Loading