src/lightning/fabric/fabric.py (241 changes: 193 additions & 48 deletions)
@@ -86,29 +86,49 @@ def _do_nothing(*_: Any) -> None:
class Fabric:
r"""Fabric accelerates your PyTorch training or inference code with minimal changes required.

- Automatic placement of models and data onto the device.
- Automatic support for mixed and double precision (smaller memory footprint).
- Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies
(data-parallel training, sharded training, etc.).
- Automated spawning of processes, no launch utilities required.
- Multi-node support.
Features:
- Automatic placement of models and data onto the device.
- Automatic support for mixed and double precision (smaller memory footprint).
- Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies
(data-parallel training, sharded training, etc.).
- Automated spawning of processes, no launch utilities required.
- Multi-node support.

Args:
accelerator: The hardware to run on. Possible choices are:
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
Defaults to ``"auto"``.
strategy: Strategy for how to run across multiple devices. Possible choices are:
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``, ``"auto"``.
Defaults to ``"auto"``.
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
The value applies per node. Defaults to ``"auto"``.
num_nodes: Number of GPU nodes for distributed training. Defaults to ``1``.
precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
or bfloat16 precision AMP (``"bf16-mixed"``).
plugins: One or several custom plugins
or bfloat16 precision AMP (``"bf16-mixed"``). If ``None``, defaults will be used based on the device.
plugins: One or several custom plugins as a single plugin or list of plugins.
callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that
can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.
loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for more
information.

Example::

# Basic usage
fabric = Fabric(accelerator="gpu", devices=2)

# Set up model and optimizer
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Training loop
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    fabric.backward(loss)
    optimizer.step()

"""

def __init__(
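
A rough sketch of how the constructor arguments documented above combine; the concrete accelerator, strategy, device count, and precision values are illustrative picks from the lists in the docstring, not recommendations::

    from lightning.fabric import Fabric

    fabric = Fabric(
        accelerator="cuda",      # "cpu", "cuda", "mps", "gpu", "tpu", or "auto"
        strategy="ddp",          # "dp", "ddp", "ddp_spawn", "deepspeed", "fsdp", or "auto"
        devices=2,               # per-node device count, a list/str of device indices, or "auto"
        num_nodes=1,
        precision="bf16-mixed",  # "64", "32", "16-mixed", or "bf16-mixed"
    )
    fabric.launch()
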
@@ -217,9 +237,9 @@ def setup(
r"""Set up a model and its optimizers for accelerated training.

Args:
module: A :class:`torch.nn.Module` to set up
*optimizers: The optimizer(s) to set up (no optimizers is also possible)
scheduler: The learning rate scheduler to set up (no learning rate scheduler is also possible)
module: A :class:`torch.nn.Module` to set up.
*optimizers: The optimizer(s) to set up. Can be zero or more optimizers.
scheduler: An optional learning rate scheduler to set up. Must be provided after optimizers if used.
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
@@ -228,8 +248,24 @@ def setup(
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.

Returns:
The tuple containing wrapped module, optimizers, and an optional learning rate scheduler,
in the same order they were passed in.
If no optimizers are passed, returns the wrapped module. If optimizers are passed, returns a tuple
containing the wrapped module and optimizers, and optionally the scheduler if provided, in the same
order they were passed in.

Note:
For certain strategies like FSDP, you may need to set up the model first using :meth:`setup_module`,
then create the optimizer, and finally set up the optimizer using :meth:`setup_optimizers`.

Example::

# Basic usage
model, optimizer = fabric.setup(model, optimizer)

# With multiple optimizers and scheduler
model, opt1, opt2, scheduler = fabric.setup(model, opt1, opt2, scheduler=scheduler)

# Model only
model = fabric.setup(model)

"""
self._validate_setup(module, optimizers)
@@ -286,15 +322,25 @@ def setup_module(
See also :meth:`setup_optimizers`.

Args:
module: A :class:`torch.nn.Module` to set up
module: A :class:`torch.nn.Module` to set up.
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.

Returns:
The wrapped model.
The wrapped model as a :class:`~lightning.fabric.wrappers._FabricModule`.

Example::

# Set up model first (useful for FSDP)
model = fabric.setup_module(model)

# Then create and set up optimizer
optimizer = torch.optim.Adam(model.parameters())
optimizer = fabric.setup_optimizers(optimizer)

"""
self._validate_setup_module(module)
@@ -334,10 +380,26 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, tu
``.setup(model, optimizer, ...)`` instead to jointly set them up.

Args:
*optimizers: One or more optimizers to set up.
*optimizers: One or more optimizers to set up. Must provide at least one optimizer.

Returns:
The wrapped optimizer(s).
If a single optimizer is passed, returns the wrapped optimizer. If multiple optimizers are passed,
returns a tuple of wrapped optimizers in the same order they were passed in.

Raises:
RuntimeError: If using DeepSpeed or XLA strategies, which require joint model-optimizer setup.
ValueError: If no optimizers are provided.

Note:
This method cannot be used with DeepSpeed or XLA strategies. Use :meth:`setup` instead for those strategies.

Example::

# Single optimizer
optimizer = fabric.setup_optimizers(optimizer)

# Multiple optimizers
opt1, opt2 = fabric.setup_optimizers(opt1, opt2)

"""
self._validate_setup_optimizers(optimizers)
@@ -355,7 +417,8 @@ def setup_dataloaders(
dataloader, call this method individually for each one.

Args:
*dataloaders: A single dataloader or a sequence of dataloaders.
*dataloaders: One or more PyTorch :class:`~torch.utils.data.DataLoader` instances to set up.
Must provide at least one dataloader.
use_distributed_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the
dataloader(s) for distributed training. If you have a custom sampler defined, set this argument
to ``False``.
@@ -364,7 +427,16 @@
returned data.

Returns:
The wrapped dataloaders, in the same order they were passed in.
If a single dataloader is passed, returns the wrapped dataloader. If multiple dataloaders are passed,
returns a list of wrapped dataloaders in the same order they were passed in.

Example::

# Single dataloader
train_loader = fabric.setup_dataloaders(train_loader)

# Multiple dataloaders
train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)

"""
self._validate_setup_dataloaders(dataloaders)
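
The ``use_distributed_sampler`` flag described above is easiest to see with a manually managed sampler. A small sketch, assuming an existing ``dataset`` object and an already-initialized process group; the batch size is arbitrary::

    from torch.utils.data import DataLoader, DistributedSampler

    # The sampler is managed by the user, so Fabric must not wrap or replace it.
    sampler = DistributedSampler(dataset, shuffle=True)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    loader = fabric.setup_dataloaders(loader, use_distributed_sampler=False)
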
@@ -410,18 +482,27 @@ def _setup_dataloader(
return fabric_dataloader

def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] = None, **kwargs: Any) -> None:
r"""Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
r"""Replaces ``loss.backward()`` in your training loop. Handles precision automatically for you.

Args:
tensor: The tensor (loss) to back-propagate gradients from.
*args: Optional positional arguments passed to the underlying backward function.
model: Optional model instance for plugins that require the model for backward().
model: Optional model instance for plugins that require the model for backward(). Required when using
DeepSpeed strategy with multiple models.
**kwargs: Optional named keyword arguments passed to the underlying backward function.

Note:
When using ``strategy="deepspeed"`` and multiple models were set up, it is required to pass in the
model as argument here.

Example::

loss = criterion(output, target)
fabric.backward(loss)

# With DeepSpeed and multiple models
fabric.backward(loss, model=model)

"""
module = model._forward_module if model is not None else model
module, _ = _unwrap_compiled(module)
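
Because ``backward()`` is where the precision plugin hooks in, one mixed-precision training step looks roughly as follows. This is a sketch assuming a Fabric created with ``precision="16-mixed"``, a model and optimizer already passed through ``setup()``, and a dataloader from ``setup_dataloaders()``::

    import torch.nn.functional as F

    for batch, target in dataloader:
        optimizer.zero_grad()
        output = model(batch)   # forward runs under the autocast context managed by Fabric
        loss = F.cross_entropy(output, target)
        fabric.backward(loss)   # replaces loss.backward(); applies loss scaling when needed
        optimizer.step()        # the wrapped optimizer handles gradient unscaling before stepping
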
@@ -459,17 +540,29 @@ def clip_gradients(
Args:
module: The module whose parameters should be clipped.
optimizer: The optimizer referencing the parameters to be clipped.
clip_val: If passed, gradients will be clipped to this value.
clip_val: If passed, gradients will be clipped to this value. Cannot be used together with ``max_norm``.
max_norm: If passed, clips the gradients in such a way that their total p-norm is
no larger than the given value.
norm_type: The type of norm if `max_norm` was passed. Can be ``'inf'`` for infinity norm.
Default is the 2-norm.
no larger than the given value. Cannot be used together with ``clip_val``.
norm_type: The type of norm if ``max_norm`` was passed. Can be ``'inf'`` for infinity norm.
Defaults to 2-norm.
error_if_nonfinite: An error is raised if the total norm of the gradients is NaN or infinite.
Only applies when ``max_norm`` is used.

Return:
Returns:
The total norm of the gradients (before clipping was applied) as a scalar tensor if ``max_norm`` was
passed, otherwise ``None``.

Raises:
ValueError: If both ``clip_val`` and ``max_norm`` are provided, or if neither is provided.

Example::

# Clip by value
fabric.clip_gradients(model, optimizer, clip_val=1.0)

# Clip by norm
total_norm = fabric.clip_gradients(model, optimizer, max_norm=1.0)

"""
if clip_val is not None and max_norm is not None:
raise ValueError(
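
For orientation, gradient clipping sits between ``backward()`` and ``optimizer.step()``. A sketch, assuming the model and optimizer were set up with Fabric; the threshold and the logging call are illustrative::

    fabric.backward(loss)

    # Clip by total 2-norm and keep the pre-clip norm for monitoring.
    grad_norm = fabric.clip_gradients(model, optimizer, max_norm=1.0, norm_type=2.0, error_if_nonfinite=True)
    fabric.log("grad_norm", grad_norm)  # assumes a logger was passed to Fabric

    optimizer.step()
    optimizer.zero_grad()
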
@@ -643,24 +736,37 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Abstr
r"""Skip gradient synchronization during backward to avoid redundant communication overhead.

Use this context manager when performing gradient accumulation to speed up training with multiple devices.

Example::

# Accumulate gradient 8 batches at a time
with fabric.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
    output = model(input)
    loss = ...
    fabric.backward(loss)
    ...

For those strategies that don't support it, a warning is emitted. For single-device strategies, it is a no-op.
Both the model's ``.forward()`` and the ``fabric.backward()`` call need to run under this context.

Args:
module: The module for which to control the gradient synchronization.
module: The module for which to control the gradient synchronization. Must be a module that was
set up with :meth:`setup` or :meth:`setup_module`.
enabled: Whether the context manager is enabled or not. ``True`` means skip the sync, ``False`` means do not
skip.

Returns:
A context manager that controls gradient synchronization.

Raises:
TypeError: If the module was not set up with Fabric first.

Note:
For strategies that don't support gradient sync control, a warning is emitted and the context manager
becomes a no-op. For single-device strategies, it is always a no-op.

Example::

# Accumulate gradients over 8 batches
for batch_idx, batch in enumerate(dataloader):
    with fabric.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
        output = model(batch)
        loss = criterion(output, target)
        fabric.backward(loss)

    if batch_idx % 8 == 0:
        optimizer.step()
        optimizer.zero_grad()

"""
module, _ = _unwrap_compiled(module)
if not isinstance(module, _FabricModule):
@@ -726,13 +832,28 @@ def save(
This method must be called on all processes!

Args:
path: A path to where the file(s) should be saved
path: A path to where the file(s) should be saved.
state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their
state-dict will be retrieved and converted automatically.
filter: An optional dictionary containing filter callables that return a boolean indicating whether the
given item should be saved (``True``) or filtered out (``False``). Each filter key should match a
state key, where its filter will be applied to the ``state_dict`` generated.

Raises:
TypeError: If filter is not a dictionary or contains non-callable values.
ValueError: If filter keys don't match state keys.

Example::

state = {"model": model, "optimizer": optimizer, "epoch": epoch}
fabric.save("checkpoint.pth", state)

# With filter
def param_filter(name, param):
    return "bias" not in name  # Save only non-bias parameters

fabric.save("checkpoint.pth", state, filter={"model": param_filter})

"""
if filter is not None:
if not isinstance(filter, dict):
@@ -759,7 +880,7 @@ def load(
This method must be called on all processes!

Args:
path: A path to where the file is located
path: A path to where the file is located.
state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
If no state is given, then the checkpoint will be returned in full.
strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
Expand All @@ -768,6 +889,16 @@ def load(
The remaining items that were not restored into the given state dictionary. If no state dictionary is
given, the full checkpoint will be returned.

Example::

# Load full checkpoint
checkpoint = fabric.load("checkpoint.pth")

# Load into existing objects
state = {"model": model, "optimizer": optimizer}
remainder = fabric.load("checkpoint.pth", state)
epoch = remainder.get("epoch", 0)

"""
unwrapped_state = _unwrap_objects(state)
remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
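
``save()`` and ``load()`` combine into a resume-from-checkpoint flow. A sketch under the assumption that the epoch counter is tracked as a plain entry in the state dictionary; ``resume_path`` and ``max_epochs`` are hypothetical user-side variables::

    state = {"model": model, "optimizer": optimizer, "epoch": 0}

    if resume_path is not None:
        # Modules and optimizers are restored in-place; plain entries such as "epoch" are updated in `state`.
        fabric.load(resume_path, state)

    for epoch in range(state["epoch"], max_epochs):
        ...  # training loop
        state["epoch"] = epoch + 1
        fabric.save("checkpoint.pth", state)
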
@@ -805,18 +936,32 @@ def launch(self, function: Callable[["Fabric"], Any] = _do_nothing, *args: Any,
Args:
function: Optional function to launch when using a spawn/fork-based strategy, for example, when using the
XLA strategy (``accelerator="tpu"``). The function must accept at least one argument, to which
the Fabric object itself will be passed.
the Fabric object itself will be passed. If not provided, only process initialization will be performed.
*args: Optional positional arguments to be passed to the function.
**kwargs: Optional keyword arguments to be passed to the function.

Returns:
The output of the function that ran in the worker process with rank 0.

The ``launch()`` method should only be used if you intend to specify accelerator, devices, and so on in
the code (programmatically). If you are launching with the Lightning CLI, ``fabric run ...``, remove
``launch()`` from your code.
Raises:
RuntimeError: If called when script was launched through the CLI.
TypeError: If function is provided but not callable, or if function doesn't accept required arguments.

Note:
The ``launch()`` method should only be used if you intend to specify accelerator, devices, and so on in
the code (programmatically). If you are launching with the Lightning CLI, ``fabric run ...``, remove
``launch()`` from your code.

The ``launch()`` is a no-op when called multiple times and no function is passed in.

Example::

def train_function(fabric):
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer = fabric.setup(model, optimizer)
    # ... training code ...

fabric = Fabric(accelerator="tpu", devices=8)
fabric.launch(train_function)

"""
if _is_using_cli():