Commit 6bf6540

Update Lightning Lite docs (4/n) (#16246)
Co-authored-by: edenlightning <[email protected]>
1 parent 6092de9 commit 6bf6540

File tree

2 files changed: +57 −39 lines changed


docs/source-pytorch/fabric/fabric.rst

Lines changed: 56 additions & 38 deletions
@@ -205,9 +205,7 @@ As you can see, this function accepts one argument, the ``Fabric`` object, and i
 Fabric Flags
 ************

-Fabric is specialized in accelerated distributed training and inference. It offers you convenient ways to configure
-your device and communication strategy and to switch seamlessly from one to the other. The terminology and usage are
-identical to Lightning, which means minimum effort for you to convert when you decide to do so.
+Fabric is designed to accelerate distributed training and inference. It makes it easy to configure your device and communication strategy, and to switch seamlessly from one to the other.


 accelerator
@@ -292,41 +290,6 @@ Configure the devices to run on. Can be of type:
     fabric = Fabric(devices="-1", accelerator="gpu")  # equivalent


-
-gpus
-====
-
-.. warning:: ``gpus=x`` has been deprecated in v1.7 and will be removed in v2.0.
-    Please use ``accelerator='gpu'`` and ``devices=x`` instead.
-
-Shorthand for setting ``devices=X`` and ``accelerator="gpu"``.
-
-.. code-block:: python
-
-    # Run on two GPUs
-    fabric = Fabric(accelerator="gpu", devices=2)
-
-    # Equivalent
-    fabric = Fabric(devices=2, accelerator="gpu")
-
-
-tpu_cores
-=========
-
-.. warning:: ``tpu_cores=x`` has been deprecated in v1.7 and will be removed in v2.0.
-    Please use ``accelerator='tpu'`` and ``devices=x`` instead.
-
-Shorthand for ``devices=X`` and ``accelerator="tpu"``.
-
-.. code-block:: python
-
-    # Run on eight TPUs
-    fabric = Fabric(accelerator="tpu", devices=8)
-
-    # Equivalent
-    fabric = Fabric(devices=8, accelerator="tpu")
-
-
 num_nodes
 =========


@@ -395,6 +358,33 @@ To define your own behavior, subclass the relevant class and pass it in. Here's
     fabric = Fabric(plugins=[MyCluster()], ...)


+
+callbacks
+=========
+
+A callback class is a collection of methods that the training loop can call at a specific point in time, for example, at the end of an epoch.
+Add callbacks to Fabric to inject logic into your training loop from an external callback class.
+
+.. code-block:: python
+
+    class MyCallback:
+        def on_train_epoch_end(self, results):
+            ...
+
+You can then register this callback, or multiple ones, directly with Fabric:
+
+.. code-block:: python
+
+    fabric = Fabric(callbacks=[MyCallback()])
+
+
+Then, in your training loop, you can call a hook by its name. Any callback objects that have this hook will execute it:
+
+.. code-block:: python
+
+    # Call any hook by name
+    fabric.call("on_train_epoch_end", results={...})
+
+
 ----------

@@ -595,3 +585,31 @@ For single-device strategies, it is a no-op. There are strategies that don't sup
 - xla

 For these, the context manager falls back to a no-op and emits a warning.
+
+
+call
+====
+
+Use this to run all registered callback hooks with a given name and inputs.
+It is useful when building a Trainer that allows the user to run arbitrary code at fixed points in the training loop.
+
+.. code-block:: python
+
+    class MyCallback:
+        def on_train_start(self):
+            ...
+
+        def on_train_epoch_end(self, model, results):
+            ...
+
+
+    fabric = Fabric(callbacks=[MyCallback()])
+
+    # Call any hook by name
+    fabric.call("on_train_start")
+
+    # Pass in additional arguments that the hook requires
+    fabric.call("on_train_epoch_end", model=..., results={...})
+
+    # Only the callbacks that have this method defined will be executed
+    fabric.call("undefined")
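
The hook-by-name dispatch that ``fabric.call`` documents above can be sketched in plain Python. ``SimpleDispatcher`` and ``RecordingCallback`` below are hypothetical names for illustration, not Fabric's actual implementation:

```python
class SimpleDispatcher:
    """Minimal callback dispatcher: runs a named hook on every registered
    callback object that defines it, and skips objects that do not."""

    def __init__(self, callbacks):
        self._callbacks = list(callbacks)

    def call(self, hook_name, **kwargs):
        for callback in self._callbacks:
            hook = getattr(callback, hook_name, None)
            if callable(hook):
                hook(**kwargs)


class RecordingCallback:
    def __init__(self):
        self.seen = []

    def on_train_epoch_end(self, results):
        self.seen.append(results)


cb = RecordingCallback()
dispatcher = SimpleDispatcher([cb])

# The hook defined on the callback is invoked with the given arguments
dispatcher.call("on_train_epoch_end", results={"loss": 0.1})

# No registered callback defines this hook, so the call is a no-op
dispatcher.call("undefined")
```

The lookup via ``getattr`` is what makes undefined hook names harmless: callbacks only opt in to the hooks they implement.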

src/lightning_fabric/fabric.py

Lines changed: 1 addition & 1 deletion
@@ -552,7 +552,7 @@ class MyCallback:
                 def on_train_epoch_end(self, results):
                     ...

-            fabric = Fabric(callbacks=[MyCallback]))
+            fabric = Fabric(callbacks=[MyCallback()])
             fabric.call("on_train_epoch_end", results={...})
         """
         for callback in self._callbacks:
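
The docstring fix above replaces the class object ``MyCallback`` (plus a stray closing parenthesis) with an instance ``MyCallback()``. The distinction matters because hooks are looked up as attributes on each registered object; a short standalone sketch of why, using only illustrative names:

```python
class MyCallback:
    def on_train_epoch_end(self, results):
        return results["loss"]


# On an instance, the hook is a bound method and can be called with
# keyword arguments directly, as the corrected docstring example does.
instance_hook = getattr(MyCallback(), "on_train_epoch_end")
loss = instance_hook(results={"loss": 0.5})

# On the class itself, the attribute is a plain function: calling it
# without an explicit `self` raises a TypeError, which is why registering
# the class (as the pre-fix docstring showed) would misbehave.
class_hook = getattr(MyCallback, "on_train_epoch_end")
try:
    class_hook(results={"loss": 0.5})
    callable_without_instance = True
except TypeError:
    callable_without_instance = False
```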
