@@ -176,7 +176,7 @@ def _collate_tensor_dataset(data_list):
176176
177177 def _collate_graph_dataset (self , data_list ):
178178 """
179- Function used to collate the data when the dataset is a
179+ Function used to collate data when the dataset is a
180180 :class:`~pina.data.dataset.PinaGraphDataset`.
181181
182182 :param data_list: Elememts to be collated.
@@ -187,7 +187,6 @@ def _collate_graph_dataset(self, data_list):
187187 :raises RuntimeError: If the data is not a
188188 :class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
189189 """
190-
191190 if isinstance (data_list [0 ], LabelTensor ):
192191 return LabelTensor .cat (data_list )
193192 if isinstance (data_list [0 ], torch .Tensor ):
@@ -201,14 +200,13 @@ def _collate_graph_dataset(self, data_list):
201200
202201 def __call__ (self , batch ):
203202 """
204- Perform the collation of the data points fetched from the dataset.
205- The behavoior of the function is set based on the batching strategy
206- during class initialization.
203+ Perform the collation of data fetched from the dataset. The behavoior
204+ of the function is set based on the batching strategy during class
205+ initialization.
207206
208207 :param batch: List of retrieved data or sampled indices.
209208 :type batch: list[int] | list[dict]
210- :return: Dictionary containing the data points fetched from the dataset,
211- collated.
209+ :return: Dictionary containing colleted data fetched from the dataset.
212210 :rtype: dict
213211 """
214212
@@ -223,12 +221,10 @@ class PinaSampler:
223221
224222 def __new__ (cls , dataset , shuffle ):
225223 """
226- Instantiate the sampler based on the environment in which the code is
227- running.
224+ Instantiate and initialize the sampler.
228225
229- :param PinaDataset dataset: The dataset to be sampled.
230- :param bool shuffle: whether to shuffle the dataset or not before
231- sampling.
226+ :param PinaDataset dataset: The dataset from which to sample.
227+ :param bool shuffle: Whether to shuffle the dataset.
232228 :return: The sampler instance.
233229 :rtype: torch.utils.data.Sampler
234230 """
@@ -267,18 +263,18 @@ def __init__(
267263 pin_memory = False ,
268264 ):
269265 """
270- Initialize the object, creating datasets based on the input problem.
266+ Initialize the object and creating datasets based on the input problem.
271267
272268 :param AbstractProblem problem: The problem containing the data on which
273269 to create the datasets and dataloaders.
274- :param float train_size: Fraction or number of elements in the training
275- split. It must be in the range [0, 1].
276- :param float test_size: Fraction or number of elements in the test
277- split. It must be in the range [0, 1].
278- :param float val_size: Fraction or number of elements in the validation
279- split. It must be in the range [0, 1].
270+ :param float train_size: Fraction of elements in the training split. It
271+ must be in the range [0, 1].
272+ :param float test_size: Fraction of elements in the test split. It must
273+ be in the range [0, 1].
274+ :param float val_size: Fraction of elements in the validation split. It
275+ must be in the range [0, 1].
280276 :param batch_size: The batch size used for training. If `None`, the
281- entire dataset is used per batch.
277+ entire dataset is returned in a single batch.
282278 :type batch_size: int | None
283279 :param bool shuffle: Whether to shuffle the dataset before splitting.
284280 Default True.
@@ -289,7 +285,7 @@ def __init__(
289285 :param int num_workers: Number of worker threads for data loading.
290286 Default 0 (serial loading).
291287 :param bool pin_memory: Whether to use pinned memory for faster data
292- transfer to GPU. ( Default False) .
288+ transfer to GPU. Default False.
293289
294290 :raises ValueError: If at least one of the splits is negative.
295291 :raises ValueError: If the sum of the splits is different from 1.
@@ -370,7 +366,7 @@ def setup(self, stage=None):
370366 If the stage is "fit", the training and validation datasets are created.
371367 If the stage is "test", the testing dataset is created.
372368
373- :param str stage: The stage for which to perform the splitting .
369+ :param str stage: The stage for which to perform the dataset setup .
374370
375371 :raises ValueError: If the stage is neither "fit" nor "test".
376372 """
@@ -534,10 +530,10 @@ def _create_dataloader(self, split, dataset):
534530
535531 def find_max_conditions_lengths (self , split ):
536532 """
537- Define the maximum length of the conditions.
533+ Define the maximum length for each conditions.
538534
539- :param dict split: The splits of the dataset.
540- :return: The maximum length of the conditions .
535+ :param dict split: The split of the dataset.
536+ :return: The maximum length per condition .
541537 :rtype: dict
542538 """
543539
0 commit comments