@@ -36,45 +36,79 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc
3636
3737.. code-block :: python
3838
39+ import torch
3940 import xpulib
4041
42+ from functools import lru_cache
43+ from typing import Any, Dict, Union
44+ from lightning.pytorch.accelerators.accelerator import Accelerator
45+
46+ from typing_extensions import override
47+
4148
4249 class XPUAccelerator (Accelerator ):
4350 """ Support for a hypothetical XPU, optimized for large-scale machine learning."""
4451
52+ @override
53+ def setup_device (self , device : torch.device) -> None :
54+ """
55+ Raises:
56+ ValueError:
57+ If the selected device is not of type hypothetical XPU.
58+ """
59+ if device.type != " xpu" :
60+ raise ValueError (f " Device should be of type 'xpu', got ' { device.type} ' instead. " )
61+ if device.index is None :
62+ device = torch.device(" xpu" , 0 )
63+ xpulib.set_device(device.index)
64+
65+ @override
66+ def teardown (self ) -> None :
67+ xpulib.empty_cache()
68+
4569 @ staticmethod
70+ @override
4671 def parse_devices (devices : Any) -> Any:
4772 # Put parsing logic here how devices can be passed into the Trainer
4873 # via the `devices` argument
4974 return devices
5075
5176 @ staticmethod
77+ @override
5278 def get_parallel_devices (devices : Any) -> Any:
5379 # Here, convert the device indices to actual device objects
5480 return [torch.device(" xpu" , idx) for idx in devices]
5581
5682 @ staticmethod
83+ @override
5784 def auto_device_count () -> int :
5885 # Return a value for auto-device selection when `Trainer(devices="auto")`
5986 return xpulib.available_devices()
6087
6188 @ staticmethod
89+ @override
6290 def is_available () -> bool :
6391 return xpulib.is_available()
6492
6593 def get_device_stats (self , device : Union[str , torch.device]) -> Dict[str , Any]:
6694 # Return optional device statistics for loggers
6795 return {}
6896
97+ @ staticmethod
98+ @override
99+ def get_device () -> str :
100+ return " xpu"
101+
69102
70103 Finally, add the XPUAccelerator to the Trainer:
71104
72105.. code-block :: python
73106
74107 from lightning.pytorch import Trainer
75-
108+ from lightning.pytorch.strategies import DDPStrategy
76109 accelerator = XPUAccelerator()
77- trainer = Trainer(accelerator = accelerator, devices = 2 )
110+ strategy = DDPStrategy(parallel_devices = accelerator.get_parallel_devices(2 ))
111+ trainer = Trainer(accelerator = accelerator, strategy = strategy, devices = 2 )
78112
79113
80114:doc: `Learn more about Strategies <../extensions/strategy >` and how they interact with the Accelerator.
@@ -93,6 +127,7 @@ If you wish to switch to a custom accelerator from the CLI without code changes,
93127 ...
94128
95129 @ classmethod
130+ @override
96131 def register_accelerators (cls , accelerator_registry ):
97132 accelerator_registry.register(
98133 " xpu" ,
0 commit comments