@@ -66,17 +66,57 @@ def __contains__(self, name: str) -> bool:
6666 return name in self .lengths
6767
6868 def length (self , name : str = "default" ) -> torch .Tensor :
69+ """Retrieves a length tensor from a sequence by name.
70+
71+ Args:
72+ name (str, optional): The name of the feature. Defaults to "default".
73+
74+ Returns:
75+ torch.Tensor: The length tensor of the specified feature.
76+
77+ Raises:
78+ ValueError: If the Sequence object has multiple lengths and
79+ no feature name is specified.
80+ """
81+
6982 if name in self .lengths :
7083 return self .lengths [name ]
7184
7285 raise ValueError ("Batch has multiple lengths, please specify a feature name" )
7386
7487 def mask (self , name : str = "default" ) -> torch .Tensor :
88+ """Retrieves a mask tensor from a sequence by name.
89+
90+ Args:
91+ name (str, optional): The name of the feature. Defaults to "default".
92+
93+ Returns:
94+ torch.Tensor: The mask tensor of the specified feature.
95+
96+ Raises:
97+ ValueError: If the Sequence object has multiple masks and
98+ no feature name is specified.
99+ """
75100 if name in self .masks :
76101 return self .masks [name ]
77102
78103 raise ValueError ("Batch has multiple masks, please specify a feature name" )
79104
105+ def device (self ) -> torch .device :
106+ """Retrieves the device of the tensors in the Sequence object.
107+
108+ Returns:
109+ torch.device: The device of the tensors.
110+
111+ Raises:
112+ ValueError: If the Sequence object is empty.
113+ """
114+ for d in self .lengths .values ():
115+ if isinstance (d , torch .Tensor ):
116+ return d .device
117+
118+ raise ValueError ("Sequence is empty" )
119+
80120
81121@torch .jit .script
82122class Batch :
@@ -126,6 +166,38 @@ def __init__(
126166 self .targets : Dict [str , torch .Tensor ] = _targets
127167 self .sequences : Optional [Sequence ] = sequences
128168
169+ @staticmethod
170+ @torch .jit .ignore
171+ def sample_from (
172+ dataset_or_loader : Union [Dataset , Loader ],
173+ batch_size : int = 32 ,
174+ shuffle : Optional [bool ] = False ,
175+ ) -> "Batch" :
176+ """Sample a batch from a dataset or a loader.
177+
178+ Example usage::
179+ dataset = merlin.io.Dataset(...)
180+ batch = Batch.sample_from(dataset)
181+
182+ Parameters
183+ ----------
184+ dataset_or_loader: merlin.io.dataset
185+ A Dataset object or a Loader object.
186+ batch_size: int, default=32
187+ Number of samples to return.
188+ shuffle: bool
189+ Whether to sample a random batch or not, by default False.
190+
191+ Returns:
192+ -------
193+ features: Dict[torch.Tensor]
194+ dictionary of feature tensors.
195+ targets: Dict[torch.Tensor]
196+ dictionary of target tensors.
197+ """
198+
199+ return sample_batch (dataset_or_loader , batch_size , shuffle )
200+
129201 def replace (
130202 self ,
131203 features : Optional [Dict [str , torch .Tensor ]] = None ,
@@ -158,7 +230,7 @@ def replace(
158230 )
159231
160232 def feature (self , name : str = "default" ) -> torch .Tensor :
161- """Retrieve a feature tensor from the batch by its name.
233+ """Retrieve a feature tensor from the batch by name.
162234
163235 Parameters
164236 ----------
@@ -182,7 +254,7 @@ def feature(self, name: str = "default") -> torch.Tensor:
182254 raise ValueError ("Batch has multiple features, please specify a feature name" )
183255
184256 def target (self , name : str = "default" ) -> torch .Tensor :
185- """Retrieve a target tensor from the batch by its name.
257+ """Retrieve a target tensor from the batch by name.
186258
187259 Parameters
188260 ----------
@@ -208,6 +280,21 @@ def target(self, name: str = "default") -> torch.Tensor:
208280 def __bool__ (self ) -> bool :
209281 return bool (self .features )
210282
283+ def device (self ) -> torch .device :
284+ """Retrieves the device of the tensors in the Batch object.
285+
286+ Returns:
287+ torch.device: The device of the tensors.
288+
289+ Raises:
290+ ValueError: If the Batch object is empty.
291+ """
292+ for d in self .features .values ():
293+ if isinstance (d , torch .Tensor ):
294+ return d .device
295+
296+ raise ValueError ("Batch is empty" )
297+
211298
212299def sample_batch (
213300 dataset_or_loader : Union [Dataset , Loader ],
@@ -237,8 +324,10 @@ def sample_batch(
237324 if not batch_size :
238325 raise ValueError ("Either use 'Loader' or specify 'batch_size'" )
239326 loader = Loader (dataset_or_loader , batch_size = batch_size , shuffle = shuffle )
240- else :
327+ elif isinstance ( dataset_or_loader , Loader ) :
241328 loader = dataset_or_loader
329+ else :
330+ raise ValueError (f"Expected Dataset or Loader instance, got: { dataset_or_loader } " )
242331
243332 batch = loader .peek ()
244333 # batch could be of type Prediction, so we can't unpack directly
0 commit comments