Skip to content

Commit 486f07b

Browse files
kaushikb11rohitgr7
andauthored
Add notes to Trainer docs when devices flag is not defined (#12155)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent f4a0069 commit 486f07b

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

docs/source/common/trainer.rst

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,26 @@ Example::
242242

243243
Trainer(accelerator=MyOwnAcc())
244244

245+
.. note::
246+
247+
If the ``devices`` flag is not defined, it will assume ``devices`` to be ``"auto"`` and fetch the ``auto_device_count``
248+
from the accelerator.
249+
250+
.. code-block:: python
251+
252+
# This is part of the built-in `GPUAccelerator`
253+
class GPUAccelerator(Accelerator):
254+
"""Accelerator for GPU devices."""
255+
256+
@staticmethod
257+
def auto_device_count() -> int:
258+
"""Get the devices when set to auto."""
259+
return torch.cuda.device_count()
260+
261+
262+
# Training with GPU Accelerator using total number of gpus available on the system
263+
Trainer(accelerator="gpu")
264+
245265
.. warning:: Passing training strategies (e.g., ``"ddp"``) to ``accelerator`` has been deprecated in v1.5.0
246266
and will be removed in v1.7.0. Please use the ``strategy`` argument instead.
247267

@@ -580,6 +600,26 @@ based on the accelerator type (``"cpu", "gpu", "tpu", "ipu", "auto"``).
580600
# Training with IPU Accelerator using 4 ipus
581601
trainer = Trainer(devices="auto", accelerator="ipu")
582602
603+
.. note::
604+
605+
If the ``devices`` flag is not defined, it will assume ``devices`` to be ``"auto"`` and fetch the ``auto_device_count``
606+
from the accelerator.
607+
608+
.. code-block:: python
609+
610+
# This is part of the built-in `GPUAccelerator`
611+
class GPUAccelerator(Accelerator):
612+
"""Accelerator for GPU devices."""
613+
614+
@staticmethod
615+
def auto_device_count() -> int:
616+
"""Get the devices when set to auto."""
617+
return torch.cuda.device_count()
618+
619+
620+
# Training with GPU Accelerator using total number of gpus available on the system
621+
Trainer(accelerator="gpu")
622+
583623
enable_checkpointing
584624
^^^^^^^^^^^^^^^^^^^^
585625

0 commit comments

Comments
 (0)