@@ -86,29 +86,49 @@ def _do_nothing(*_: Any) -> None:
class Fabric:
r"""Fabric accelerates your PyTorch training or inference code with minimal changes required.
- - Automatic placement of models and data onto the device.
- - Automatic support for mixed and double precision (smaller memory footprint).
- - Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies
- (data-parallel training, sharded training, etc.).
- - Automated spawning of processes, no launch utilities required.
- - Multi-node support.
+ Key Features:
+ - Automatic placement of models and data onto the device.
+ - Automatic support for mixed and double precision (smaller memory footprint).
+ - Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies
+ (data-parallel training, sharded training, etc.).
+ - Automated spawning of processes, no launch utilities required.
+ - Multi-node support.
Args:
accelerator: The hardware to run on. Possible choices are:
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
+ Defaults to ``"auto"``.
strategy: Strategy for how to run across multiple devices. Possible choices are:
- ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
+ ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``, ``"auto"``.
+ Defaults to ``"auto"``.
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
- The value applies per node.
- num_nodes: Number of GPU nodes for distributed training.
+ The value applies per node. Defaults to ``"auto"``.
+ num_nodes: Number of GPU nodes for distributed training. Defaults to ``1``.
precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
- or bfloat16 precision AMP (``"bf16-mixed"``).
- plugins: One or several custom plugins
+ or bfloat16 precision AMP (``"bf16-mixed"``). If ``None``, defaults will be used based on the device.
+ plugins: One or several custom plugins, passed as a single plugin or a list of plugins.
callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that
can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.
loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for more
information.
+ Example::
+
+     # Basic usage
+     fabric = Fabric(accelerator="gpu", devices=2)
+
+     # Set up model and optimizer
+     model = MyModel()
+     optimizer = torch.optim.Adam(model.parameters())
+     model, optimizer = fabric.setup(model, optimizer)
+
+     # Training loop
+     for batch in dataloader:
+         optimizer.zero_grad()
+         loss = model(batch)
+         fabric.backward(loss)
+         optimizer.step()
+
"""
def __init__(
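Taken together, the pieces documented above — ``launch()``, ``setup()``, ``setup_dataloaders()``, ``backward()``, and a callback invoked through ``call()`` — fit into a minimal script along these lines (a sketch only; ``MyModel``, ``MyDataset``, and the callback are assumed user code)::

    import torch
    from torch.utils.data import DataLoader
    from lightning.fabric import Fabric

    class PrintCallback:
        def on_epoch_end(self, epoch):  # any method name works; invoked via fabric.call()
            print(f"finished epoch {epoch}")

    fabric = Fabric(accelerator="auto", devices=2, callbacks=[PrintCallback()])
    fabric.launch()  # set up processes (omit when starting the script with `fabric run ...`)

    model = MyModel()                                  # assumed user-defined nn.Module
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer = fabric.setup(model, optimizer)
    dataloader = fabric.setup_dataloaders(DataLoader(MyDataset(), batch_size=32))

    for epoch in range(3):
        for batch in dataloader:
            optimizer.zero_grad()
            loss = model(batch)        # shorthand: the model returns the loss here
            fabric.backward(loss)
            optimizer.step()
        fabric.call("on_epoch_end", epoch=epoch)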
@@ -217,9 +237,9 @@ def setup(
r"""Set up a model and its optimizers for accelerated training.
Args:
- module: A :class:`torch.nn.Module` to set up
- *optimizers: The optimizer(s) to set up (no optimizers is also possible)
- scheduler: The learning rate scheduler to set up (no learning rate scheduler is also possible)
+ module: A :class:`torch.nn.Module` to set up.
+ *optimizers: The optimizer(s) to set up. Can be zero or more optimizers.
+ scheduler: An optional learning rate scheduler to set up. If used, it must be passed as a keyword argument after the optimizers.
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
@@ -228,8 +248,24 @@ def setup(
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
Returns:
- The tuple containing wrapped module, optimizers, and an optional learning rate scheduler,
- in the same order they were passed in.
+ If no optimizers are passed, returns the wrapped module. If optimizers are passed, returns a tuple
+ containing the wrapped module and optimizers, and optionally the scheduler if provided, in the same
+ order they were passed in.
+
+ Note:
+     For certain strategies like FSDP, you may need to set up the model first using :meth:`setup_module`,
+     then create the optimizer, and finally set up the optimizer using :meth:`setup_optimizers`.
+
+ Example::
+
+     # Basic usage
+     model, optimizer = fabric.setup(model, optimizer)
+
+     # With multiple optimizers and scheduler
+     model, opt1, opt2, scheduler = fabric.setup(model, opt1, opt2, scheduler=scheduler)
+
+     # Model only
+     model = fabric.setup(model)
"""
self._validate_setup(module, optimizers)
@@ -286,15 +322,25 @@ def setup_module(
See also :meth:`setup_optimizers`.
Args:
- module: A :class:`torch.nn.Module` to set up
+ module: A :class:`torch.nn.Module` to set up.
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
+
Returns:
- The wrapped model.
+ The wrapped model as a :class:`~lightning.fabric.wrappers._FabricModule`.
+
+ Example::
+
+     # Set up model first (useful for FSDP)
+     model = fabric.setup_module(model)
+
+     # Then create and set up optimizer
+     optimizer = torch.optim.Adam(model.parameters())
+     optimizer = fabric.setup_optimizers(optimizer)
"""
self._validate_setup_module(module)
@@ -334,10 +380,25 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, tu
``.setup(model, optimizer, ...)`` instead to jointly set them up.
Args:
- *optimizers: One or more optimizers to set up.
+ *optimizers: One or more optimizers to set up. Must provide at least one optimizer.
Returns:
- The wrapped optimizer(s).
+ If a single optimizer is passed, returns the wrapped optimizer. If multiple optimizers are passed,
+ returns a tuple of wrapped optimizers in the same order they were passed in.
+
+ Raises:
+     RuntimeError: If using DeepSpeed or XLA strategies, which require joint model-optimizer setup.
+
+ Note:
+     This method cannot be used with the DeepSpeed or XLA strategy; use :meth:`setup` instead.
+
+ Example::
+
+     # Single optimizer
+     optimizer = fabric.setup_optimizers(optimizer)
+
+     # Multiple optimizers
+     opt1, opt2 = fabric.setup_optimizers(opt1, opt2)
"""
self._validate_setup_optimizers(optimizers)
@@ -355,7 +416,7 @@ def setup_dataloaders(
dataloader, call this method individually for each one.
Args:
- *dataloaders: A single dataloader or a sequence of dataloaders.
+ *dataloaders: One or more PyTorch :class:`~torch.utils.data.DataLoader` instances to set up.
use_distributed_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the
dataloader(s) for distributed training. If you have a custom sampler defined, set this argument
to ``False``.
@@ -364,7 +425,16 @@ def setup_dataloaders(
returned data.
Returns:
- The wrapped dataloaders, in the same order they were passed in.
+ If a single dataloader is passed, returns the wrapped dataloader. If multiple dataloaders are passed,
+ returns a list of wrapped dataloaders in the same order they were passed in.
+
+ Example::
+
+     # Single dataloader
+     train_loader = fabric.setup_dataloaders(train_loader)
+
+     # Multiple dataloaders
+     train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)
"""
self._validate_setup_dataloaders(dataloaders)
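With ``use_distributed_sampler=False``, a manually configured sampler is kept as-is rather than replaced; a sketch, assuming a user-built ``DistributedSampler`` over an existing ``dataset``::

    from torch.utils.data import DataLoader, DistributedSampler

    # Build the sampler yourself, then tell Fabric not to inject its own
    sampler = DistributedSampler(dataset, num_replicas=fabric.world_size, rank=fabric.global_rank)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
    dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)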
@@ -410,18 +480,27 @@ def _setup_dataloader(
return fabric_dataloader

def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] = None, **kwargs: Any) -> None:
- r"""Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
+ r"""Replaces ``loss.backward()`` in your training loop. Handles precision automatically for you.
Args:
tensor: The tensor (loss) to back-propagate gradients from.
*args: Optional positional arguments passed to the underlying backward function.
- model: Optional model instance for plugins that require the model for backward().
+ model: Optional model instance for plugins that require the model for backward(). Required when using
+ the DeepSpeed strategy with multiple models.
**kwargs: Optional named keyword arguments passed to the underlying backward function.
Note:
When using ``strategy="deepspeed"`` and multiple models were set up, it is required to pass in the
model as argument here.

+ Example::
+
+     loss = criterion(output, target)
+     fabric.backward(loss)
+
+     # With DeepSpeed and multiple models
+     fabric.backward(loss, model=model)
+
"""
module = model._forward_module if model is not None else model
module, _ = _unwrap_compiled(module)
@@ -459,17 +538,29 @@ def clip_gradients(
Args:
module: The module whose parameters should be clipped.
optimizer: The optimizer referencing the parameters to be clipped.
- clip_val: If passed, gradients will be clipped to this value.
+ clip_val: If passed, gradients will be clipped to this value. Cannot be used together with ``max_norm``.
max_norm: If passed, clips the gradients in such a way that the p-norm of the resulting parameters is
- no larger than the given value.
- norm_type: The type of norm if `max_norm` was passed. Can be ``'inf'`` for infinity norm.
- Default is the 2-norm.
+ no larger than the given value. Cannot be used together with ``clip_val``.
+ norm_type: The type of norm if ``max_norm`` was passed. Can be ``'inf'`` for infinity norm.
+ Defaults to the 2-norm.
error_if_nonfinite: An error is raised if the total norm of the gradients is NaN or infinite.
+ Only applies when ``max_norm`` is used.
- Return:
+ Returns:
The total norm of the gradients (before clipping was applied) as a scalar tensor if ``max_norm`` was
passed, otherwise ``None``.

+ Raises:
+     ValueError: If both ``clip_val`` and ``max_norm`` are provided, or if neither is provided.
+
+ Example::
+
+     # Clip by value
+     fabric.clip_gradients(model, optimizer, clip_val=1.0)
+
+     # Clip by norm
+     total_norm = fabric.clip_gradients(model, optimizer, max_norm=1.0)
+
"""
if clip_val is not None and max_norm is not None:
    raise ValueError(
@@ -643,24 +734,37 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Abstr
r"""Skip gradient synchronization during backward to avoid redundant communication overhead.
Use this context manager when performing gradient accumulation to speed up training with multiple devices.
-
- Example::
-
-     # Accumulate gradient 8 batches at a time
-     with fabric.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
-         output = model(input)
-         loss = ...
-         fabric.backward(loss)
-         ...
-
- For those strategies that don't support it, a warning is emitted. For single-device strategies, it is a no-op.
Both the model's ``.forward()`` and the ``fabric.backward()`` call need to run under this context.
Args:
- module: The module for which to control the gradient synchronization.
+ module: The module for which to control the gradient synchronization. Must be a module that was
+ set up with :meth:`setup` or :meth:`setup_module`.
enabled: Whether the context manager is enabled or not. ``True`` means skip the sync, ``False`` means do not
skip.
+ Returns:
+     A context manager that controls gradient synchronization.
+
+ Raises:
+     TypeError: If the module was not set up with Fabric first.
+
+ Note:
+     For strategies that don't support gradient sync control, a warning is emitted and the context manager
+     becomes a no-op. For single-device strategies, it is always a no-op.
+
+ Example::
+
+     # Accumulate gradients over 8 batches
+     for batch_idx, batch in enumerate(dataloader):
+         with fabric.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
+             output = model(batch)
+             loss = criterion(output, target)
+             fabric.backward(loss)
+
+         if batch_idx % 8 == 0:
+             optimizer.step()
+             optimizer.zero_grad()
+
"""
module, _ = _unwrap_compiled(module)
if not isinstance(module, _FabricModule):
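When gradient accumulation is combined with gradient clipping, the clipping belongs on the synchronization step, after ``fabric.backward()`` and before ``optimizer.step()``; a sketch, assuming ``model`` and ``optimizer`` were set up with ``fabric.setup()``::

    accumulate = 8
    for batch_idx, batch in enumerate(dataloader):
        is_sync_step = (batch_idx + 1) % accumulate == 0
        # Skip the gradient all-reduce on pure accumulation steps
        with fabric.no_backward_sync(model, enabled=not is_sync_step):
            loss = model(batch)    # shorthand: the model returns the loss
            fabric.backward(loss)
        if is_sync_step:
            # Gradients are complete across processes; clip, then update
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()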
@@ -726,13 +830,28 @@ def save(
This method must be called on all processes!
Args:
- path: A path to where the file(s) should be saved
+ path: A path to where the file(s) should be saved.
state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their
state-dict will be retrieved and converted automatically.
filter: An optional dictionary containing filter callables that return a boolean indicating whether the
given item should be saved (``True``) or filtered out (``False``). Each filter key should match a
state key, where its filter will be applied to the ``state_dict`` generated.
+ Raises:
+     TypeError: If filter is not a dictionary or contains non-callable values.
+     ValueError: If filter keys don't match state keys.
+
+ Example::
+
+     state = {"model": model, "optimizer": optimizer, "epoch": epoch}
+     fabric.save("checkpoint.pth", state)
+
+     # With filter
+     def param_filter(name, param):
+         return "bias" not in name  # Save only non-bias parameters
+
+     fabric.save("checkpoint.pth", state, filter={"model": param_filter})
+
"""
if filter is not None:
    if not isinstance(filter, dict):
@@ -759,7 +878,7 @@ def load(
This method must be called on all processes!
Args:
- path: A path to where the file is located
+ path: A path to where the file is located.
state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
If no state is given, then the checkpoint will be returned in full.
strict: Whether to enforce that the keys in `state` match the keys in the checkpoint.
@@ -768,6 +887,16 @@ def load(
The remaining items that were not restored into the given state dictionary. If no state dictionary is
given, the full checkpoint will be returned.

+ Example::
+
+     # Load full checkpoint
+     checkpoint = fabric.load("checkpoint.pth")
+
+     # Load into existing objects
+     state = {"model": model, "optimizer": optimizer}
+     remainder = fabric.load("checkpoint.pth", state)
+     epoch = remainder.get("epoch", 0)
+
"""
unwrapped_state = _unwrap_objects(state)
remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
@@ -805,18 +934,32 @@ def launch(self, function: Callable[["Fabric"], Any] = _do_nothing, *args: Any,
Args:
function: Optional function to launch when using a spawn/fork-based strategy, for example, when using the
XLA strategy (``accelerator="tpu"``). The function must accept at least one argument, to which
- the Fabric object itself will be passed.
+ the Fabric object itself will be passed. If not provided, only process initialization will be performed.
*args: Optional positional arguments to be passed to the function.
**kwargs: Optional keyword arguments to be passed to the function.
Returns:
Returns the output of the function that ran in worker process with rank 0.
- The ``launch()`` method should only be used if you intend to specify accelerator, devices, and so on in
- the code (programmatically). If you are launching with the Lightning CLI, ``fabric run ...``, remove
- ``launch()`` from your code.
+ Raises:
+     RuntimeError: If called when the script was launched through the CLI.
+     TypeError: If the function is provided but not callable, or does not accept the required arguments.
+
+ Note:
+     The ``launch()`` method should only be used if you intend to specify accelerator, devices, and so on in
+     the code (programmatically). If you are launching with the Lightning CLI, ``fabric run ...``, remove
+     ``launch()`` from your code.
+
+     The ``launch()`` is a no-op when called multiple times and no function is passed in.
+
+ Example::
+
+     def train_function(fabric):
+         model = MyModel()          # user-defined nn.Module, as in the class example
+         optimizer = torch.optim.Adam(model.parameters())
+         model, optimizer = fabric.setup(model, optimizer)
+         # ... training code ...

- The ``launch()`` is a no-op when called multiple times and no function is passed in.
+     fabric = Fabric(accelerator="tpu", devices=8)
+     fabric.launch(train_function)
"""
if _is_using_cli():
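As the Args above note, extra positional and keyword arguments given to ``launch()`` are forwarded to the launched function; a sketch, with ``MyModel`` again standing in for user code::

    def train_function(fabric, lr, max_epochs=10):
        model = MyModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        model, optimizer = fabric.setup(model, optimizer)
        # ... run the training loop for max_epochs ...

    fabric = Fabric(accelerator="tpu", devices=8)
    fabric.launch(train_function, 1e-3, max_epochs=5)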