diff --git a/src/lightning/pytorch/accelerators/__init__.py b/src/lightning/pytorch/accelerators/__init__.py index 4cadee51f64c7..d7c2197aa5ed4 100644 --- a/src/lightning/pytorch/accelerators/__init__.py +++ b/src/lightning/pytorch/accelerators/__init__.py @@ -10,16 +10,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +__all__ = [ + "Accelerator", + "CPUAccelerator", + "CUDAAccelerator", + "MPSAccelerator", + "XLAAccelerator", + "find_usable_cuda_devices", +] + import sys -from lightning.fabric.accelerators import find_usable_cuda_devices # noqa: F401 +from lightning.fabric.accelerators import find_usable_cuda_devices from lightning.fabric.accelerators.registry import _AcceleratorRegistry from lightning.fabric.utilities.registry import _register_classes from lightning.pytorch.accelerators.accelerator import Accelerator -from lightning.pytorch.accelerators.cpu import CPUAccelerator # noqa: F401 -from lightning.pytorch.accelerators.cuda import CUDAAccelerator # noqa: F401 -from lightning.pytorch.accelerators.mps import MPSAccelerator # noqa: F401 -from lightning.pytorch.accelerators.xla import XLAAccelerator # noqa: F401 +from lightning.pytorch.accelerators.cpu import CPUAccelerator +from lightning.pytorch.accelerators.cuda import CUDAAccelerator +from lightning.pytorch.accelerators.mps import MPSAccelerator +from lightning.pytorch.accelerators.xla import XLAAccelerator AcceleratorRegistry = _AcceleratorRegistry() _register_classes(AcceleratorRegistry, "register_accelerators", sys.modules[__name__], Accelerator)