Skip to content

Commit 004ff77

Browse files
authored
Adding sample_batch & sample_features (#1095)
* Adding sample_batch & sample_features * Output Batch instead * Increase test-coverage + doc-strings * Rename to data as pointed out in PR-review
1 parent c8bcfa7 commit 004ff77

File tree

2 files changed

+238
-3
lines changed

2 files changed

+238
-3
lines changed

merlin/models/torch/batch.py

Lines changed: 157 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818

1919
import torch
2020

21+
from merlin.dataloader.torch import Loader
22+
from merlin.io import Dataset
23+
2124

2225
@torch.jit.script
2326
class Sequence:
@@ -63,17 +66,57 @@ def __contains__(self, name: str) -> bool:
6366
return name in self.lengths
6467

6568
def length(self, name: str = "default") -> torch.Tensor:
    """Look up the length tensor registered under ``name``.

    Args:
        name (str, optional): Key into ``self.lengths``. Defaults to "default".

    Returns:
        torch.Tensor: The entry of ``self.lengths`` stored under ``name``.

    Raises:
        ValueError: If ``name`` is not a key of ``self.lengths``, e.g. when
            no name was given and there is no "default" entry.
    """
    # Guard clause: reject unknown names first so the lookup is the fall-through path.
    if name not in self.lengths:
        raise ValueError("Batch has multiple lengths, please specify a feature name")

    return self.lengths[name]
7086

7187
def mask(self, name: str = "default") -> torch.Tensor:
    """Look up the mask tensor registered under ``name``.

    Args:
        name (str, optional): Key into ``self.masks``. Defaults to "default".

    Returns:
        torch.Tensor: The entry of ``self.masks`` stored under ``name``.

    Raises:
        ValueError: If ``name`` is not a key of ``self.masks``, e.g. when
            no name was given and there is no "default" entry.
    """
    # Guard clause: reject unknown names first so the lookup is the fall-through path.
    if name not in self.masks:
        raise ValueError("Batch has multiple masks, please specify a feature name")

    return self.masks[name]
76104

105+
def device(self) -> torch.device:
    """Retrieves the device of the tensors in the Sequence object.

    The length tensors are inspected first. If there are none, the mask
    tensors are inspected instead, so that a Sequence holding only masks
    still reports a device rather than being treated as empty.

    Returns:
        torch.device: The device of the first tensor found.

    Raises:
        ValueError: If the Sequence holds neither length nor mask tensors.
    """
    for tensor in self.lengths.values():
        if isinstance(tensor, torch.Tensor):
            return tensor.device

    # Fall back to the masks: a Sequence without lengths is not empty
    # as long as it still carries mask tensors.
    for tensor in self.masks.values():
        if isinstance(tensor, torch.Tensor):
            return tensor.device

    raise ValueError("Sequence is empty")
119+
77120

78121
@torch.jit.script
79122
class Batch:
@@ -123,6 +166,38 @@ def __init__(
123166
self.targets: Dict[str, torch.Tensor] = _targets
124167
self.sequences: Optional[Sequence] = sequences
125168

169+
@staticmethod
@torch.jit.ignore
def sample_from(
    dataset_or_loader: Union[Dataset, Loader],
    batch_size: int = 32,
    shuffle: Optional[bool] = False,
) -> "Batch":
    """Draw a single Batch from a dataset or a loader.

    Thin wrapper that forwards all of its arguments to ``sample_batch``.

    Example usage::
        dataset = merlin.io.Dataset(...)
        batch = Batch.sample_from(dataset)

    Parameters
    ----------
    dataset_or_loader: Union[merlin.io.Dataset, Loader]
        A Dataset object or a Loader object.
    batch_size: int, default=32
        Number of samples to return.
    shuffle: bool
        Whether to sample a random batch or not, by default False.

    Returns
    -------
    Batch
        The value returned by ``sample_batch`` for these arguments.
    """
    return sample_batch(dataset_or_loader, batch_size=batch_size, shuffle=shuffle)
200+
126201
def replace(
127202
self,
128203
features: Optional[Dict[str, torch.Tensor]] = None,
@@ -155,7 +230,7 @@ def replace(
155230
)
156231

157232
def feature(self, name: str = "default") -> torch.Tensor:
158-
"""Retrieve a feature tensor from the batch by its name.
233+
"""Retrieve a feature tensor from the batch by name.
159234
160235
Parameters
161236
----------
@@ -179,7 +254,7 @@ def feature(self, name: str = "default") -> torch.Tensor:
179254
raise ValueError("Batch has multiple features, please specify a feature name")
180255

181256
def target(self, name: str = "default") -> torch.Tensor:
182-
"""Retrieve a target tensor from the batch by its name.
257+
"""Retrieve a target tensor from the batch by name.
183258
184259
Parameters
185260
----------
@@ -204,3 +279,83 @@ def target(self, name: str = "default") -> torch.Tensor:
204279

205280
def __bool__(self) -> bool:
206281
return bool(self.features)
282+
283+
def device(self) -> torch.device:
    """Report the device on which the batch's feature tensors live.

    Returns:
        torch.device: Device of the first feature tensor encountered.

    Raises:
        ValueError: If ``self.features`` contains no tensor.
    """
    for feature in self.features.values():
        # Skip anything that is not a tensor; the first tensor decides.
        if not isinstance(feature, torch.Tensor):
            continue
        return feature.device

    raise ValueError("Batch is empty")
297+
298+
299+
def sample_batch(
    data: Union[Dataset, Loader],
    batch_size: Optional[int] = None,
    shuffle: Optional[bool] = False,
) -> Batch:
    """Util function to generate a Batch from a merlin.io.Dataset or a Loader.

    Parameters
    ----------
    data: Union[merlin.io.Dataset, Loader]
        A Dataset object, which is wrapped in a new Loader, or an existing
        Loader, which is used as-is.
    batch_size: int
        Number of samples to return. Required when ``data`` is a Dataset;
        not used when ``data`` is already a Loader.
    shuffle: bool
        Whether to sample a random batch or not, by default False.
        Only passed on when a Loader is created from a Dataset.

    Returns
    -------
    Batch
        Batch built from the first two entries of ``loader.peek()``
        (inputs and targets).

    Raises
    ------
    ValueError
        If ``data`` is neither a Dataset nor a Loader, or if ``data`` is a
        Dataset and no ``batch_size`` was given.
    """
    # Reject unsupported inputs up front.
    if not isinstance(data, (Dataset, Loader)):
        raise ValueError(f"Expected Dataset or Loader instance, got: {data}")

    if isinstance(data, Dataset):
        if not batch_size:
            raise ValueError("Either use 'Loader' or specify 'batch_size'")
        loader = Loader(data, batch_size=batch_size, shuffle=shuffle)
    else:
        loader = data

    peeked = loader.peek()
    # peeked could be of type Prediction, so index instead of unpacking directly
    return Batch(peeked[0], peeked[1])
337+
338+
339+
def sample_features(
    data: Union[Dataset, Loader],
    batch_size: Optional[int] = None,
    shuffle: Optional[bool] = False,
) -> Dict[str, torch.Tensor]:
    """Util function to generate a dict of feature tensors from a Dataset or a Loader.

    All arguments are handed to ``sample_batch``; only the ``features``
    attribute of the resulting batch is returned.

    Parameters
    ----------
    data: Union[merlin.io.Dataset, Loader]
        A Dataset object or a Loader object.
    batch_size: int
        Number of samples to return.
    shuffle: bool
        Whether to sample a random batch or not, by default False.

    Returns
    -------
    features: Dict[str, torch.Tensor]
        dictionary of feature tensors.
    """
    sampled = sample_batch(data, batch_size, shuffle)
    return sampled.features

tests/unit/torch/test_batch.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
import pytest
1818
import torch
1919

20-
from merlin.models.torch.batch import Batch, Sequence
20+
from merlin.dataloader.torch import Loader
21+
from merlin.models.torch.batch import Batch, Sequence, sample_batch, sample_features
2122

2223

2324
class TestSequence:
@@ -56,6 +57,7 @@ def test_init_tensor_lengths(self):
5657
assert isinstance(sequence.lengths, dict)
5758
assert "default" in sequence.lengths
5859
assert torch.equal(sequence.lengths["default"], lengths)
60+
assert sequence.device() == lengths.device
5961

6062
def test_init_tensor_masks(self):
6163
# Test when masks is a tensor
@@ -90,6 +92,12 @@ def test_init_invalid_masks(self):
9092
with pytest.raises(ValueError, match="Masks must be a tensor or a dictionary of tensors"):
9193
Sequence(lengths, masks)
9294

95+
def test_device(self):
    """device() on a Sequence built from an empty dict is expected to raise."""
    sequence_without_tensors = Sequence({})

    with pytest.raises(ValueError, match="Sequence is empty"):
        sequence_without_tensors.device()
100+
93101

94102
class TestBatch:
95103
@pytest.fixture
@@ -128,6 +136,7 @@ def test_batch_init_tensor_target(self):
128136
assert isinstance(batch.targets, dict)
129137
assert "default" in batch.targets
130138
assert torch.equal(batch.targets["default"], targets)
139+
assert batch.device() == features.device
131140

132141
def test_batch_init_invalid_targets(self):
133142
# Test when targets is not a tensor nor a dictionary of tensors
@@ -155,3 +164,74 @@ def test_bool(self, batch):
155164
def test_with_incorrect_types(self):
    """Constructing a Batch from plain strings is expected to raise ValueError."""
    bad_features = "not a tensor or dict"
    bad_targets = "not a tensor or dict"
    bad_sequences = "not a sequence"

    with pytest.raises(ValueError):
        Batch(bad_features, bad_targets, bad_sequences)
167+
168+
def test_sample(self, music_streaming_data):
    """Batch.sample_from on the fixture dataset yields a Batch with 32 rows."""
    expected_rows = 32
    skipped_suffixes = ("__values", "__offsets")

    sampled = Batch.sample_from(music_streaming_data)
    assert isinstance(sampled, Batch)

    features = sampled.features
    assert isinstance(features, dict)
    assert len(features) == 12
    for name, tensor in features.items():
        # Keys ending in these suffixes are excluded from the row-count check.
        if name.endswith(skipped_suffixes):
            continue
        assert tensor.shape[0] == expected_rows

    targets = sampled.targets
    assert isinstance(targets, dict)
    assert list(targets) == ["click", "play_percentage", "like"]
    for tensor in targets.values():
        assert tensor.shape[0] == expected_rows
182+
183+
def test_device(self):
    """device() on a Batch built from empty dicts is expected to raise."""
    batch_without_tensors = Batch({}, {})

    with pytest.raises(ValueError, match="Batch is empty"):
        batch_without_tensors.device()
188+
189+
190+
class Test_sample_batch:
    """Tests for the module-level ``sample_batch`` helper."""

    @staticmethod
    def _assert_batch_layout(batch, expected_rows):
        """Shared checks on the features and targets of a sampled batch."""
        features = batch.features
        assert isinstance(features, dict)
        assert len(features) == 12
        for name, tensor in features.items():
            # Keys ending in these suffixes are excluded from the row-count check.
            if name.endswith(("__values", "__offsets")):
                continue
            assert tensor.shape[0] == expected_rows

        targets = batch.targets
        assert isinstance(targets, dict)
        assert list(targets) == ["click", "play_percentage", "like"]
        for tensor in targets.values():
            assert tensor.shape[0] == expected_rows

    def test_loader(self, music_streaming_data):
        """A ready-made Loader can be sampled without passing batch_size."""
        loader = Loader(music_streaming_data, batch_size=2)

        self._assert_batch_layout(sample_batch(loader), expected_rows=2)

    def test_dataset(self, music_streaming_data):
        """A Dataset can be sampled when batch_size is given."""
        sampled = sample_batch(music_streaming_data, batch_size=2)

        self._assert_batch_layout(sampled, expected_rows=2)

    def test_exceptions(self, music_streaming_data):
        """A Dataset without batch_size and an unsupported input both raise ValueError."""
        with pytest.raises(ValueError, match="specify 'batch_size'"):
            sample_batch(music_streaming_data)

        unsupported_input = torch.tensor([1, 2, 3])
        with pytest.raises(ValueError, match="Expected Dataset or Loader instance"):
            sample_batch(unsupported_input)
227+
228+
229+
class Test_sample_features:
    """Tests for the module-level ``sample_features`` helper."""

    def test_no_targets(self, music_streaming_data):
        """sample_features returns a dict of 12 features with batch_size rows each."""
        expected_rows = 2
        sampled = sample_features(music_streaming_data, batch_size=expected_rows)

        assert isinstance(sampled, dict)
        assert len(sampled) == 12
        for name, tensor in sampled.items():
            # Keys ending in these suffixes are excluded from the row-count check.
            if name.endswith(("__values", "__offsets")):
                continue
            assert tensor.shape[0] == expected_rows

0 commit comments

Comments
 (0)