1818
1919import torch
2020
21+ from merlin .dataloader .torch import Loader
22+ from merlin .io import Dataset
23+
2124
2225@torch .jit .script
2326class Sequence :
@@ -63,17 +66,57 @@ def __contains__(self, name: str) -> bool:
6366 return name in self .lengths
6467
6568 def length (self , name : str = "default" ) -> torch .Tensor :
69+ """Retrieves a length tensor from a sequence by name.
70+
71+ Args:
72+ name (str, optional): The name of the feature. Defaults to "default".
73+
74+ Returns:
75+ torch.Tensor: The length tensor of the specified feature.
76+
77+ Raises:
78+ ValueError: If the Sequence object has multiple lengths and
79+ no feature name is specified.
80+ """
81+
6682 if name in self .lengths :
6783 return self .lengths [name ]
6884
6985 raise ValueError ("Batch has multiple lengths, please specify a feature name" )
7086
7187 def mask (self , name : str = "default" ) -> torch .Tensor :
88+ """Retrieves a mask tensor from a sequence by name.
89+
90+ Args:
91+ name (str, optional): The name of the feature. Defaults to "default".
92+
93+ Returns:
94+ torch.Tensor: The mask tensor of the specified feature.
95+
96+ Raises:
97+ ValueError: If the Sequence object has multiple masks and
98+ no feature name is specified.
99+ """
72100 if name in self .masks :
73101 return self .masks [name ]
74102
75103 raise ValueError ("Batch has multiple masks, please specify a feature name" )
76104
105+ def device (self ) -> torch .device :
106+ """Retrieves the device of the tensors in the Sequence object.
107+
108+ Returns:
109+ torch.device: The device of the tensors.
110+
111+ Raises:
112+ ValueError: If the Sequence object is empty.
113+ """
114+ for d in self .lengths .values ():
115+ if isinstance (d , torch .Tensor ):
116+ return d .device
117+
118+ raise ValueError ("Sequence is empty" )
119+
77120
78121@torch .jit .script
79122class Batch :
@@ -123,6 +166,38 @@ def __init__(
123166 self .targets : Dict [str , torch .Tensor ] = _targets
124167 self .sequences : Optional [Sequence ] = sequences
125168
169+ @staticmethod
170+ @torch .jit .ignore
171+ def sample_from (
172+ dataset_or_loader : Union [Dataset , Loader ],
173+ batch_size : int = 32 ,
174+ shuffle : Optional [bool ] = False ,
175+ ) -> "Batch" :
176+ """Sample a batch from a dataset or a loader.
177+
178+ Example usage::
179+ dataset = merlin.io.Dataset(...)
180+ batch = Batch.sample_from(dataset)
181+
182+ Parameters
183+ ----------
184+ dataset_or_loader: merlin.io.dataset
185+ A Dataset object or a Loader object.
186+ batch_size: int, default=32
187+ Number of samples to return.
188+ shuffle: bool
189+ Whether to sample a random batch or not, by default False.
190+
191+ Returns:
192+ -------
193+ features: Dict[torch.Tensor]
194+ dictionary of feature tensors.
195+ targets: Dict[torch.Tensor]
196+ dictionary of target tensors.
197+ """
198+
199+ return sample_batch (dataset_or_loader , batch_size , shuffle )
200+
126201 def replace (
127202 self ,
128203 features : Optional [Dict [str , torch .Tensor ]] = None ,
@@ -155,7 +230,7 @@ def replace(
155230 )
156231
157232 def feature (self , name : str = "default" ) -> torch .Tensor :
158- """Retrieve a feature tensor from the batch by its name.
233+ """Retrieve a feature tensor from the batch by name.
159234
160235 Parameters
161236 ----------
@@ -179,7 +254,7 @@ def feature(self, name: str = "default") -> torch.Tensor:
179254 raise ValueError ("Batch has multiple features, please specify a feature name" )
180255
181256 def target (self , name : str = "default" ) -> torch .Tensor :
182- """Retrieve a target tensor from the batch by its name.
257+ """Retrieve a target tensor from the batch by name.
183258
184259 Parameters
185260 ----------
@@ -204,3 +279,83 @@ def target(self, name: str = "default") -> torch.Tensor:
204279
205280 def __bool__ (self ) -> bool :
206281 return bool (self .features )
282+
283+ def device (self ) -> torch .device :
284+ """Retrieves the device of the tensors in the Batch object.
285+
286+ Returns:
287+ torch.device: The device of the tensors.
288+
289+ Raises:
290+ ValueError: If the Batch object is empty.
291+ """
292+ for d in self .features .values ():
293+ if isinstance (d , torch .Tensor ):
294+ return d .device
295+
296+ raise ValueError ("Batch is empty" )
297+
298+
299+ def sample_batch (
300+ data : Union [Dataset , Loader ],
301+ batch_size : Optional [int ] = None ,
302+ shuffle : Optional [bool ] = False ,
303+ ) -> Batch :
304+ """Util function to generate a batch of input tensors from a merlin.io.Dataset instance
305+
306+ Parameters
307+ ----------
308+ data: merlin.io.dataset
309+ A Dataset object.
310+ batch_size: int
311+ Number of samples to return.
312+ shuffle: bool
313+ Whether to sample a random batch or not, by default False.
314+
315+ Returns:
316+ -------
317+ features: Dict[torch.Tensor]
318+ dictionary of feature tensors.
319+ targets: Dict[torch.Tensor]
320+ dictionary of target tensors.
321+ """
322+
323+ if isinstance (data , Dataset ):
324+ if not batch_size :
325+ raise ValueError ("Either use 'Loader' or specify 'batch_size'" )
326+ loader = Loader (data , batch_size = batch_size , shuffle = shuffle )
327+ elif isinstance (data , Loader ):
328+ loader = data
329+ else :
330+ raise ValueError (f"Expected Dataset or Loader instance, got: { data } " )
331+
332+ batch = loader .peek ()
333+ # batch could be of type Prediction, so we can't unpack directly
334+ inputs , targets = batch [0 ], batch [1 ]
335+
336+ return Batch (inputs , targets )
337+
338+
339+ def sample_features (
340+ data : Union [Dataset , Loader ],
341+ batch_size : Optional [int ] = None ,
342+ shuffle : Optional [bool ] = False ,
343+ ) -> Dict [str , torch .Tensor ]:
344+ """Util function to generate a dict of feature tensors from a merlin.io.Dataset instance
345+
346+ Parameters
347+ ----------
348+ data: merlin.io.dataset
349+ A Dataset object.
350+ batch_size: int
351+ Number of samples to return.
352+ shuffle: bool
353+ Whether to sample a random batch or not, by default False.
354+
355+ Returns:
356+ -------
357+ features: Dict[torch.Tensor]
358+ dictionary of feature tensors.
359+ """
360+
361+ return sample_batch (data , batch_size , shuffle ).features
0 commit comments