
Commit ae3ae6b

enhance documentation
1 parent baf3e5c commit ae3ae6b

File tree

1 file changed: +37 -2 lines


docs/source-pytorch/extensions/accelerator.rst

Lines changed: 37 additions & 2 deletions
@@ -36,45 +36,79 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc
 .. code-block:: python

+    import torch
     import xpulib

+    from functools import lru_cache
+    from typing import Any, Dict, Union
+    from lightning.pytorch.accelerators.accelerator import Accelerator
+
+    from typing_extensions import override
+


     class XPUAccelerator(Accelerator):
         """Support for a hypothetical XPU, optimized for large-scale machine learning."""

+        @override
+        def setup_device(self, device: torch.device) -> None:
+            """
+            Raises:
+                ValueError:
+                    If the selected device is not of type hypothetical XPU.
+            """
+            if device.type != "xpu":
+                raise ValueError(f"Device should be of type 'xpu', got '{device.type}' instead.")
+            if device.index is None:
+                device = torch.device("xpu", 0)
+            xpulib.set_device(device.index)
+
+        @override
+        def teardown(self) -> None:
+            xpulib.empty_cache()
+
         @staticmethod
+        @override
         def parse_devices(devices: Any) -> Any:
             # Put parsing logic here how devices can be passed into the Trainer
             # via the `devices` argument
             return devices

         @staticmethod
+        @override
         def get_parallel_devices(devices: Any) -> Any:
             # Here, convert the device indices to actual device objects
             return [torch.device("xpu", idx) for idx in devices]

         @staticmethod
+        @override
         def auto_device_count() -> int:
             # Return a value for auto-device selection when `Trainer(devices="auto")`
             return xpulib.available_devices()

         @staticmethod
+        @override
         def is_available() -> bool:
             return xpulib.is_available()

         def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
             # Return optional device statistics for loggers
             return {}

+        @staticmethod
+        @override
+        def get_device() -> str:
+            return "xpu"
+

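Most of what this hunk adds is the @override decorator from typing_extensions, along with the new setup_device, teardown, and get_device hooks. @override is purely a static-typing aid: at runtime it only marks the function, and a type checker such as mypy or pyright then verifies that every decorated method actually overrides something in the base class. A minimal sketch of the check, independent of Lightning (Base and Child are made-up names for illustration):

.. code-block:: python

    from typing_extensions import override


    class Base:
        def setup_device(self) -> None:
            ...


    class Child(Base):
        @override
        def setup_device(self) -> None:  # OK: Base defines setup_device
            ...

        @override
        def setup_devise(self) -> None:  # type checker error: no matching method in Base
            ...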
Finally, add the XPUAccelerator to the Trainer:

 .. code-block:: python

     from lightning.pytorch import Trainer
-
+    from lightning.pytorch.strategies import DDPStrategy
     accelerator = XPUAccelerator()
-    trainer = Trainer(accelerator=accelerator, devices=2)
+    strategy = DDPStrategy(parallel_devices=accelerator.get_parallel_devices(2))
+    trainer = Trainer(accelerator=accelerator, strategy=strategy, devices=2)

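As a point of reference, a small sketch of what the two static helpers return when the device indices are spelled out by hand. It assumes two XPUs with indices 0 and 1; in normal use the Trainer derives the indices from devices=2 and performs this conversion itself:

.. code-block:: python

    devices = XPUAccelerator.parse_devices([0, 1])  # passthrough in this example -> [0, 1]
    parallel_devices = XPUAccelerator.get_parallel_devices(devices)
    # -> [torch.device("xpu", 0), torch.device("xpu", 1)]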
:doc:`Learn more about Strategies <../extensions/strategy>` and how they interact with the Accelerator.
@@ -93,6 +127,7 @@ If you wish to switch to a custom accelerator from the CLI without code changes,
         ...

         @classmethod
+        @override
         def register_accelerators(cls, accelerator_registry):
             accelerator_registry.register(
                 "xpu",

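Once register_accelerators has added the class to the registry under the name "xpu", the accelerator should be selectable by that string rather than by passing an instance. A sketch of the intended usage, assuming the registration code above has been imported:

.. code-block:: python

    from lightning.pytorch import Trainer

    # Select the custom accelerator by its registered name.
    trainer = Trainer(accelerator="xpu", devices=2)

    # With LightningCLI the same selection would look something like:
    #   python train.py fit --trainer.accelerator=xpu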
0 commit comments

Comments
 (0)