Skip to content

Commit 8b7e321

Browse files
committed
Increase test-coverage + doc-strings
1 parent 69d71f1 commit 8b7e321

File tree

2 files changed

+128
-3
lines changed

2 files changed

+128
-3
lines changed

merlin/models/torch/batch.py

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,57 @@ def __contains__(self, name: str) -> bool:
6666
return name in self.lengths
6767

6868
def length(self, name: str = "default") -> torch.Tensor:
69+
"""Retrieves a length tensor from a sequence by name.
70+
71+
Args:
72+
name (str, optional): The name of the feature. Defaults to "default".
73+
74+
Returns:
75+
torch.Tensor: The length tensor of the specified feature.
76+
77+
Raises:
78+
ValueError: If the Sequence object has multiple lengths and
79+
no feature name is specified.
80+
"""
81+
6982
if name in self.lengths:
7083
return self.lengths[name]
7184

7285
raise ValueError("Batch has multiple lengths, please specify a feature name")
7386

7487
def mask(self, name: str = "default") -> torch.Tensor:
88+
"""Retrieves a mask tensor from a sequence by name.
89+
90+
Args:
91+
name (str, optional): The name of the feature. Defaults to "default".
92+
93+
Returns:
94+
torch.Tensor: The mask tensor of the specified feature.
95+
96+
Raises:
97+
ValueError: If the Sequence object has multiple masks and
98+
no feature name is specified.
99+
"""
75100
if name in self.masks:
76101
return self.masks[name]
77102

78103
raise ValueError("Batch has multiple masks, please specify a feature name")
79104

105+
def device(self) -> torch.device:
106+
"""Retrieves the device of the tensors in the Sequence object.
107+
108+
Returns:
109+
torch.device: The device of the tensors.
110+
111+
Raises:
112+
ValueError: If the Sequence object is empty.
113+
"""
114+
for d in self.lengths.values():
115+
if isinstance(d, torch.Tensor):
116+
return d.device
117+
118+
raise ValueError("Sequence is empty")
119+
80120

81121
@torch.jit.script
82122
class Batch:
@@ -126,6 +166,38 @@ def __init__(
126166
self.targets: Dict[str, torch.Tensor] = _targets
127167
self.sequences: Optional[Sequence] = sequences
128168

169+
@staticmethod
170+
@torch.jit.ignore
171+
def sample_from(
172+
dataset_or_loader: Union[Dataset, Loader],
173+
batch_size: int = 32,
174+
shuffle: Optional[bool] = False,
175+
) -> "Batch":
176+
"""Sample a batch from a dataset or a loader.
177+
178+
Example usage::
179+
dataset = merlin.io.Dataset(...)
180+
batch = Batch.sample_from(dataset)
181+
182+
Parameters
183+
----------
184+
dataset_or_loader: merlin.io.dataset
185+
A Dataset object or a Loader object.
186+
batch_size: int, default=32
187+
Number of samples to return.
188+
shuffle: bool
189+
Whether to sample a random batch or not, by default False.
190+
191+
Returns:
192+
-------
193+
features: Dict[torch.Tensor]
194+
dictionary of feature tensors.
195+
targets: Dict[torch.Tensor]
196+
dictionary of target tensors.
197+
"""
198+
199+
return sample_batch(dataset_or_loader, batch_size, shuffle)
200+
129201
def replace(
130202
self,
131203
features: Optional[Dict[str, torch.Tensor]] = None,
@@ -158,7 +230,7 @@ def replace(
158230
)
159231

160232
def feature(self, name: str = "default") -> torch.Tensor:
161-
"""Retrieve a feature tensor from the batch by its name.
233+
"""Retrieve a feature tensor from the batch by name.
162234
163235
Parameters
164236
----------
@@ -182,7 +254,7 @@ def feature(self, name: str = "default") -> torch.Tensor:
182254
raise ValueError("Batch has multiple features, please specify a feature name")
183255

184256
def target(self, name: str = "default") -> torch.Tensor:
185-
"""Retrieve a target tensor from the batch by its name.
257+
"""Retrieve a target tensor from the batch by name.
186258
187259
Parameters
188260
----------
@@ -208,6 +280,21 @@ def target(self, name: str = "default") -> torch.Tensor:
208280
def __bool__(self) -> bool:
209281
return bool(self.features)
210282

283+
def device(self) -> torch.device:
284+
"""Retrieves the device of the tensors in the Batch object.
285+
286+
Returns:
287+
torch.device: The device of the tensors.
288+
289+
Raises:
290+
ValueError: If the Batch object is empty.
291+
"""
292+
for d in self.features.values():
293+
if isinstance(d, torch.Tensor):
294+
return d.device
295+
296+
raise ValueError("Batch is empty")
297+
211298

212299
def sample_batch(
213300
dataset_or_loader: Union[Dataset, Loader],
@@ -237,8 +324,10 @@ def sample_batch(
237324
if not batch_size:
238325
raise ValueError("Either use 'Loader' or specify 'batch_size'")
239326
loader = Loader(dataset_or_loader, batch_size=batch_size, shuffle=shuffle)
240-
else:
327+
elif isinstance(dataset_or_loader, Loader):
241328
loader = dataset_or_loader
329+
else:
330+
raise ValueError(f"Expected Dataset or Loader instance, got: {dataset_or_loader}")
242331

243332
batch = loader.peek()
244333
# batch could be of type Prediction, so we can't unpack directly

tests/unit/torch/test_batch.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_init_tensor_lengths(self):
5757
assert isinstance(sequence.lengths, dict)
5858
assert "default" in sequence.lengths
5959
assert torch.equal(sequence.lengths["default"], lengths)
60+
assert sequence.device() == lengths.device
6061

6162
def test_init_tensor_masks(self):
6263
# Test when masks is a tensor
@@ -91,6 +92,12 @@ def test_init_invalid_masks(self):
9192
with pytest.raises(ValueError, match="Masks must be a tensor or a dictionary of tensors"):
9293
Sequence(lengths, masks)
9394

95+
def test_device(self):
96+
empty_seq = Sequence({})
97+
98+
with pytest.raises(ValueError, match="Sequence is empty"):
99+
empty_seq.device()
100+
94101

95102
class TestBatch:
96103
@pytest.fixture
@@ -129,6 +136,7 @@ def test_batch_init_tensor_target(self):
129136
assert isinstance(batch.targets, dict)
130137
assert "default" in batch.targets
131138
assert torch.equal(batch.targets["default"], targets)
139+
assert batch.device() == features.device
132140

133141
def test_batch_init_invalid_targets(self):
134142
# Test when targets is not a tensor nor a dictionary of tensors
@@ -157,6 +165,27 @@ def test_with_incorrect_types(self):
157165
with pytest.raises(ValueError):
158166
Batch("not a tensor or dict", "not a tensor or dict", "not a sequence")
159167

168+
def test_sample(self, music_streaming_data):
169+
batch = Batch.sample_from(music_streaming_data)
170+
assert isinstance(batch, Batch)
171+
172+
assert isinstance(batch.features, dict)
173+
assert len(list(batch.features.keys())) == 12
174+
for key, val in batch.features.items():
175+
if not key.endswith("__values") and not key.endswith("__offsets"):
176+
assert val.shape[0] == 32
177+
178+
assert isinstance(batch.targets, dict)
179+
assert list(batch.targets.keys()) == ["click", "play_percentage", "like"]
180+
for val in batch.targets.values():
181+
assert val.shape[0] == 32
182+
183+
def test_device(self):
184+
empty_batch = Batch({}, {})
185+
186+
with pytest.raises(ValueError, match="Batch is empty"):
187+
empty_batch.device()
188+
160189

161190
class Test_sample_batch:
162191
def test_loader(self, music_streaming_data):
@@ -189,6 +218,13 @@ def test_dataset(self, music_streaming_data):
189218
for val in batch.targets.values():
190219
assert val.shape[0] == 2
191220

221+
def test_exceptions(self, music_streaming_data):
222+
with pytest.raises(ValueError, match="specify 'batch_size'"):
223+
sample_batch(music_streaming_data)
224+
225+
with pytest.raises(ValueError, match="Expected Dataset or Loader instance"):
226+
sample_batch(torch.tensor([1, 2, 3]))
227+
192228

193229
class Test_sample_features:
194230
def test_no_targets(self, music_streaming_data):

0 commit comments

Comments
 (0)