Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions pina/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
automatic_batching=None,
num_workers=None,
pin_memory=None,
shuffle=None,
**kwargs,
):
"""
Expand All @@ -34,13 +35,13 @@ def __init__(
If ``batch_size=None`` all
samples are loaded and data are not batched, defaults to None.
:type batch_size: int | None
:param train_size: percentage of elements in the train dataset
:param train_size: Percentage of elements in the train dataset.
:type train_size: float
:param test_size: percentage of elements in the test dataset
:param test_size: Percentage of elements in the test dataset.
:type test_size: float
:param val_size: percentage of elements in the val dataset
:param val_size: Percentage of elements in the val dataset.
:type val_size: float
:param predict_size: percentage of elements in the predict dataset
:param predict_size: Percentage of elements in the predict dataset.
:type predict_size: float
:param compile: if True model is compiled before training,
default False. For Windows users compilation is always disabled.
Expand All @@ -49,9 +50,13 @@ def __init__(
performed. Please avoid using automatic batching when batch_size is
large, default False.
:type automatic_batching: bool
:param num_workers: Number of worker threads for data loading. Default 0 (serial loading)
:param num_workers: Number of worker threads for data loading.
Default 0 (serial loading).
:type num_workers: int
:param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
:param pin_memory: Whether to use pinned memory for faster data
transfer to GPU. Default False.
:type pin_memory: bool
:param shuffle: Whether to shuffle the data for training. Default False.
:type shuffle: bool

:Keyword Arguments:
Expand All @@ -77,6 +82,10 @@ def __init__(
check_consistency(pin_memory, int)
else:
num_workers = 0
if shuffle is not None:
check_consistency(shuffle, bool)
else:
shuffle = False
if train_size + test_size + val_size + predict_size > 1:
raise ValueError(
"train_size, test_size, val_size and predict_size "
Expand Down Expand Up @@ -131,6 +140,7 @@ def __init__(
automatic_batching,
pin_memory,
num_workers,
shuffle,
)

# logging
Expand Down Expand Up @@ -166,6 +176,7 @@ def _create_datamodule(
automatic_batching,
pin_memory,
num_workers,
shuffle,
):
"""
This method is used here because resampling is needed
Expand Down Expand Up @@ -196,6 +207,7 @@ def _create_datamodule(
automatic_batching=automatic_batching,
num_workers=num_workers,
pin_memory=pin_memory,
shuffle=shuffle,
)

def train(self, **kwargs):
Expand Down
Loading