Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 157 additions & 2 deletions merlin/models/torch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

import torch

from merlin.dataloader.torch import Loader
from merlin.io import Dataset


@torch.jit.script
class Sequence:
Expand Down Expand Up @@ -63,17 +66,57 @@ def __contains__(self, name: str) -> bool:
return name in self.lengths

def length(self, name: str = "default") -> torch.Tensor:
"""Retrieves a length tensor from a sequence by name.

Args:
name (str, optional): The name of the feature. Defaults to "default".

Returns:
torch.Tensor: The length tensor of the specified feature.

Raises:
ValueError: If the Sequence object has multiple lengths and
no feature name is specified.
"""

if name in self.lengths:
return self.lengths[name]

raise ValueError("Batch has multiple lengths, please specify a feature name")

def mask(self, name: str = "default") -> torch.Tensor:
"""Retrieves a mask tensor from a sequence by name.

Args:
name (str, optional): The name of the feature. Defaults to "default".

Returns:
torch.Tensor: The mask tensor of the specified feature.

Raises:
ValueError: If the Sequence object has multiple masks and
no feature name is specified.
"""
if name in self.masks:
return self.masks[name]

raise ValueError("Batch has multiple masks, please specify a feature name")

def device(self) -> torch.device:
"""Retrieves the device of the tensors in the Sequence object.

Returns:
torch.device: The device of the tensors.

Raises:
ValueError: If the Sequence object is empty.
"""
for d in self.lengths.values():
if isinstance(d, torch.Tensor):
return d.device

raise ValueError("Sequence is empty")


@torch.jit.script
class Batch:
Expand Down Expand Up @@ -123,6 +166,38 @@ def __init__(
self.targets: Dict[str, torch.Tensor] = _targets
self.sequences: Optional[Sequence] = sequences

@staticmethod
@torch.jit.ignore
def sample_from(
dataset_or_loader: Union[Dataset, Loader],
batch_size: int = 32,
shuffle: Optional[bool] = False,
) -> "Batch":
"""Sample a batch from a dataset or a loader.

Example usage::
dataset = merlin.io.Dataset(...)
batch = Batch.sample_from(dataset)

Parameters
----------
dataset_or_loader: merlin.io.dataset
A Dataset object or a Loader object.
batch_size: int, default=32
Number of samples to return.
shuffle: bool
Whether to sample a random batch or not, by default False.

Returns:
-------
features: Dict[torch.Tensor]
dictionary of feature tensors.
targets: Dict[torch.Tensor]
dictionary of target tensors.
"""

return sample_batch(dataset_or_loader, batch_size, shuffle)

def replace(
self,
features: Optional[Dict[str, torch.Tensor]] = None,
Expand Down Expand Up @@ -155,7 +230,7 @@ def replace(
)

def feature(self, name: str = "default") -> torch.Tensor:
"""Retrieve a feature tensor from the batch by its name.
"""Retrieve a feature tensor from the batch by name.

Parameters
----------
Expand All @@ -179,7 +254,7 @@ def feature(self, name: str = "default") -> torch.Tensor:
raise ValueError("Batch has multiple features, please specify a feature name")

def target(self, name: str = "default") -> torch.Tensor:
"""Retrieve a target tensor from the batch by its name.
"""Retrieve a target tensor from the batch by name.

Parameters
----------
Expand All @@ -204,3 +279,83 @@ def target(self, name: str = "default") -> torch.Tensor:

def __bool__(self) -> bool:
return bool(self.features)

def device(self) -> torch.device:
"""Retrieves the device of the tensors in the Batch object.

Returns:
torch.device: The device of the tensors.

Raises:
ValueError: If the Batch object is empty.
"""
for d in self.features.values():
if isinstance(d, torch.Tensor):
return d.device

raise ValueError("Batch is empty")


def sample_batch(
dataset_or_loader: Union[Dataset, Loader],
batch_size: Optional[int] = None,
shuffle: Optional[bool] = False,
) -> Batch:
"""Util function to generate a batch of input tensors from a merlin.io.Dataset instance

Parameters
----------
data: merlin.io.dataset
A Dataset object.
batch_size: int
Number of samples to return.
shuffle: bool
Whether to sample a random batch or not, by default False.

Returns:
-------
features: Dict[torch.Tensor]
dictionary of feature tensors.
targets: Dict[torch.Tensor]
dictionary of target tensors.
"""

if isinstance(dataset_or_loader, Dataset):
if not batch_size:
raise ValueError("Either use 'Loader' or specify 'batch_size'")
loader = Loader(dataset_or_loader, batch_size=batch_size, shuffle=shuffle)
elif isinstance(dataset_or_loader, Loader):
loader = dataset_or_loader
else:
raise ValueError(f"Expected Dataset or Loader instance, got: {dataset_or_loader}")

batch = loader.peek()
# batch could be of type Prediction, so we can't unpack directly
inputs, targets = batch[0], batch[1]

return Batch(inputs, targets)


def sample_features(
dataset_or_loader: Union[Dataset, Loader],
batch_size: Optional[int] = None,
shuffle: Optional[bool] = False,
) -> Dict[str, torch.Tensor]:
"""Util function to generate a dict of feature tensors from a merlin.io.Dataset instance

Parameters
----------
data: merlin.io.dataset
A Dataset object.
batch_size: int
Number of samples to return.
shuffle: bool
Whether to sample a random batch or not, by default False.

Returns:
-------
features: Dict[torch.Tensor]
dictionary of feature tensors.
"""

return sample_batch(dataset_or_loader, batch_size, shuffle).features
82 changes: 81 additions & 1 deletion tests/unit/torch/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import pytest
import torch

from merlin.models.torch.batch import Batch, Sequence
from merlin.dataloader.torch import Loader
from merlin.models.torch.batch import Batch, Sequence, sample_batch, sample_features


class TestSequence:
Expand Down Expand Up @@ -56,6 +57,7 @@ def test_init_tensor_lengths(self):
assert isinstance(sequence.lengths, dict)
assert "default" in sequence.lengths
assert torch.equal(sequence.lengths["default"], lengths)
assert sequence.device() == lengths.device

def test_init_tensor_masks(self):
# Test when masks is a tensor
Expand Down Expand Up @@ -90,6 +92,12 @@ def test_init_invalid_masks(self):
with pytest.raises(ValueError, match="Masks must be a tensor or a dictionary of tensors"):
Sequence(lengths, masks)

def test_device(self):
empty_seq = Sequence({})

with pytest.raises(ValueError, match="Sequence is empty"):
empty_seq.device()


class TestBatch:
@pytest.fixture
Expand Down Expand Up @@ -128,6 +136,7 @@ def test_batch_init_tensor_target(self):
assert isinstance(batch.targets, dict)
assert "default" in batch.targets
assert torch.equal(batch.targets["default"], targets)
assert batch.device() == features.device

def test_batch_init_invalid_targets(self):
# Test when targets is not a tensor nor a dictionary of tensors
Expand Down Expand Up @@ -155,3 +164,74 @@ def test_bool(self, batch):
def test_with_incorrect_types(self):
with pytest.raises(ValueError):
Batch("not a tensor or dict", "not a tensor or dict", "not a sequence")

def test_sample(self, music_streaming_data):
batch = Batch.sample_from(music_streaming_data)
assert isinstance(batch, Batch)

assert isinstance(batch.features, dict)
assert len(list(batch.features.keys())) == 12
for key, val in batch.features.items():
if not key.endswith("__values") and not key.endswith("__offsets"):
assert val.shape[0] == 32

assert isinstance(batch.targets, dict)
assert list(batch.targets.keys()) == ["click", "play_percentage", "like"]
for val in batch.targets.values():
assert val.shape[0] == 32

def test_device(self):
empty_batch = Batch({}, {})

with pytest.raises(ValueError, match="Batch is empty"):
empty_batch.device()


class Test_sample_batch:
def test_loader(self, music_streaming_data):
loader = Loader(music_streaming_data, batch_size=2)

batch = sample_batch(loader)

assert isinstance(batch.features, dict)
assert len(list(batch.features.keys())) == 12
for key, val in batch.features.items():
if not key.endswith("__values") and not key.endswith("__offsets"):
assert val.shape[0] == 2

assert isinstance(batch.targets, dict)
assert list(batch.targets.keys()) == ["click", "play_percentage", "like"]
for val in batch.targets.values():
assert val.shape[0] == 2

def test_dataset(self, music_streaming_data):
batch = sample_batch(music_streaming_data, batch_size=2)

assert isinstance(batch.features, dict)
assert len(list(batch.features.keys())) == 12
for key, val in batch.features.items():
if not key.endswith("__values") and not key.endswith("__offsets"):
assert val.shape[0] == 2

assert isinstance(batch.targets, dict)
assert list(batch.targets.keys()) == ["click", "play_percentage", "like"]
for val in batch.targets.values():
assert val.shape[0] == 2

def test_exceptions(self, music_streaming_data):
with pytest.raises(ValueError, match="specify 'batch_size'"):
sample_batch(music_streaming_data)

with pytest.raises(ValueError, match="Expected Dataset or Loader instance"):
sample_batch(torch.tensor([1, 2, 3]))


class Test_sample_features:
def test_no_targets(self, music_streaming_data):
features = sample_features(music_streaming_data, batch_size=2)

assert isinstance(features, dict)
assert len(list(features.keys())) == 12
for key, val in features.items():
if not key.endswith("__values") and not key.endswith("__offsets"):
assert val.shape[0] == 2