|
17 | 17 | from contextlib import contextmanager, nullcontext |
18 | 18 | from datetime import timedelta |
19 | 19 | from pathlib import Path |
20 | | -from typing import ( |
21 | | - TYPE_CHECKING, |
22 | | - Any, |
23 | | - Literal, |
24 | | - Optional, |
25 | | -) |
| 20 | +from typing import TYPE_CHECKING, Any, Literal, Optional, Union |
26 | 21 |
|
27 | 22 | import torch |
28 | 23 | from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only |
@@ -151,14 +146,14 @@ def __init__( |
151 | 146 | precision_plugin: Precision | None = None, |
152 | 147 | process_group_backend: str | None = None, |
153 | 148 | timeout: timedelta | None = default_pg_timeout, |
154 | | - cpu_offload: bool | "CPUOffload" | None = None, |
| 149 | + cpu_offload: Union[bool, "CPUOffload"] | None = None, |
155 | 150 | mixed_precision: Optional["MixedPrecision"] = None, |
156 | 151 | auto_wrap_policy: Optional["_POLICY"] = None, |
157 | 152 | activation_checkpointing: type[Module] | list[type[Module]] | None = None, |
158 | 153 | activation_checkpointing_policy: Optional["_POLICY"] = None, |
159 | 154 | sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", |
160 | 155 | state_dict_type: Literal["full", "sharded"] = "full", |
161 | | - device_mesh: tuple[int] | "DeviceMesh" | None = None, |
| 156 | + device_mesh: Union[tuple[int], "DeviceMesh"] | None = None, |
162 | 157 | **kwargs: Any, |
163 | 158 | ) -> None: |
164 | 159 | super().__init__( |
|
0 commit comments