Skip to content

Commit 2921336

Browse files
committed
Output Batch instead
1 parent d00e67a commit 2921336

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

merlin/models/torch/batch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
#
1616

17-
from typing import Dict, Optional, Tuple, Union
17+
from typing import Dict, Optional, Union
1818

1919
import torch
2020

@@ -213,7 +213,7 @@ def sample_batch(
213213
dataset_or_loader: Union[Dataset, Loader],
214214
batch_size: Optional[int] = None,
215215
shuffle: Optional[bool] = False,
216-
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
216+
) -> Batch:
217217
"""Util function to generate a batch of input tensors from a merlin.io.Dataset instance
218218
219219
Parameters
@@ -244,7 +244,7 @@ def sample_batch(
244244
# batch could be of type Prediction, so we can't unpack directly
245245
inputs, targets = batch[0], batch[1]
246246

247-
return inputs, targets
247+
return Batch(inputs, targets)
248248

249249

250250
def sample_features(
@@ -269,4 +269,4 @@ def sample_features(
269269
dictionary of feature tensors.
270270
"""
271271

272-
return sample_batch(dataset_or_loader, batch_size, shuffle)[0]
272+
return sample_batch(dataset_or_loader, batch_size, shuffle).features

tests/unit/torch/test_batch.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -162,31 +162,31 @@ class Test_sample_batch:
162162
def test_loader(self, music_streaming_data):
163163
loader = Loader(music_streaming_data, batch_size=2)
164164

165-
features, targets = sample_batch(loader)
165+
batch = sample_batch(loader)
166166

167-
assert isinstance(features, dict)
168-
assert len(list(features.keys())) == 12
169-
for key, val in features.items():
167+
assert isinstance(batch.features, dict)
168+
assert len(list(batch.features.keys())) == 12
169+
for key, val in batch.features.items():
170170
if not key.endswith("__values") and not key.endswith("__offsets"):
171171
assert val.shape[0] == 2
172172

173-
assert isinstance(targets, dict)
174-
assert list(targets.keys()) == ["click", "play_percentage", "like"]
175-
for val in targets.values():
173+
assert isinstance(batch.targets, dict)
174+
assert list(batch.targets.keys()) == ["click", "play_percentage", "like"]
175+
for val in batch.targets.values():
176176
assert val.shape[0] == 2
177177

178178
def test_dataset(self, music_streaming_data):
179-
features, targets = sample_batch(music_streaming_data, batch_size=2)
179+
batch = sample_batch(music_streaming_data, batch_size=2)
180180

181-
assert isinstance(features, dict)
182-
assert len(list(features.keys())) == 12
183-
for key, val in features.items():
181+
assert isinstance(batch.features, dict)
182+
assert len(list(batch.features.keys())) == 12
183+
for key, val in batch.features.items():
184184
if not key.endswith("__values") and not key.endswith("__offsets"):
185185
assert val.shape[0] == 2
186186

187-
assert isinstance(targets, dict)
188-
assert list(targets.keys()) == ["click", "play_percentage", "like"]
189-
for val in targets.values():
187+
assert isinstance(batch.targets, dict)
188+
assert list(batch.targets.keys()) == ["click", "play_percentage", "like"]
189+
for val in batch.targets.values():
190190
assert val.shape[0] == 2
191191

192192

0 commit comments

Comments
 (0)