
Commit 7db8793

fix: fix type annotation.
1 parent a383d79 commit 7db8793

File tree: 4 files changed, +6 −6 lines changed

src/lightning/pytorch/accelerators/accelerator.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC
-from typing import Any, Optional
+from typing import Any

 import lightning.pytorch as pl
 from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator
@@ -47,6 +47,6 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
         raise NotImplementedError

     @classmethod
-    def device_name(cls, device: Optional = None) -> str:
+    def device_name(cls, device: _DEVICE = None) -> str:
         """Get the device name for a given device."""
         return str(cls.is_available())
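
A bare Optional annotation with no type argument tells a type checker nothing about what values the device parameter accepts, which is what this commit replaces. Below is a minimal sketch of how the patched signature reads once the alias is in scope; the definition of _DEVICE is assumed here to match the Union alias in lightning.fabric.utilities.types and is not confirmed by this diff.

# Minimal sketch, not the library's actual definitions.
# Assumption: _DEVICE is the Union alias from lightning.fabric.utilities.types.
from typing import Union

import torch

_DEVICE = Union[torch.device, str, int]  # assumed alias definition


def device_name(device: _DEVICE = None) -> str:
    # Mirrors the patched signature: a torch.device, a string such as "cuda:0",
    # or an integer index are all valid; None is kept as the default sentinel.
    return str(device)


print(device_name("cuda:0"))  # -> "cuda:0"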

src/lightning/pytorch/accelerators/cuda.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         )

     @classmethod
-    def device_name(cls, device: Optional[torch.types.Device] = None) -> str:
+    def device_name(cls, device: _DEVICE = None) -> str:
         if not cls.is_available():
             return "False"
         return torch.cuda.get_device_name(device)
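
For reference, torch.cuda.get_device_name already accepts every form the new annotation is meant to cover. A quick check, guarded so it also runs on machines without a GPU:

import torch

if torch.cuda.is_available():
    # get_device_name takes an integer index, a "cuda:N" string, or a torch.device.
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.get_device_name("cuda:0"))
    print(torch.cuda.get_device_name(torch.device("cuda:0")))
else:
    print("False")  # matches the accelerator's fallback return value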

src/lightning/pytorch/accelerators/mps.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         )

     @classmethod
-    def device_name(cls, device: Optional = None) -> str:
+    def device_name(cls, device: _DEVICE = None) -> str:
         # todo: implement a better way to get the device name
         available = cls.is_available()
         gpu_type = " (mps)" if available else ""

src/lightning/pytorch/accelerators/xla.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any

 from typing_extensions import override

@@ -56,7 +56,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register("tpu", cls, description=cls.__name__)

     @classmethod
-    def device_name(cls, device: Optional = None) -> str:
+    def device_name(cls, device: _DEVICE = None) -> str:
         is_available = cls.is_available()
         if not is_available:
             return str(is_available)
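
Taken together, the change gives the device_name classmethod the same device parameter type across accelerators. A usage sketch, assuming the hook is exposed on the accelerator classes as shown in the hunks above; class names are taken from the file paths in this commit, and each call degrades to "False" when the backend is unavailable:

from lightning.pytorch.accelerators.cuda import CUDAAccelerator
from lightning.pytorch.accelerators.mps import MPSAccelerator

# Each subclass returns a readable device name, or "False" when the backend
# is unavailable on the current machine.
for accelerator_cls in (CUDAAccelerator, MPSAccelerator):
    print(accelerator_cls.__name__, accelerator_cls.device_name())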
