Skip to content

Commit 156452b

Browse files
committed
fix optional imports
1 parent 283c771 commit 156452b

File tree

1 file changed

+3
-8
lines changed
  • src/lightning/pytorch/strategies

1 file changed

+3
-8
lines changed

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,7 @@
1717
from contextlib import contextmanager, nullcontext
1818
from datetime import timedelta
1919
from pathlib import Path
20-
from typing import (
21-
TYPE_CHECKING,
22-
Any,
23-
Literal,
24-
Optional,
25-
)
20+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
2621

2722
import torch
2823
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -151,14 +146,14 @@ def __init__(
151146
precision_plugin: Precision | None = None,
152147
process_group_backend: str | None = None,
153148
timeout: timedelta | None = default_pg_timeout,
154-
cpu_offload: bool | "CPUOffload" | None = None,
149+
cpu_offload: Union[bool, "CPUOffload"] | None = None,
155150
mixed_precision: Optional["MixedPrecision"] = None,
156151
auto_wrap_policy: Optional["_POLICY"] = None,
157152
activation_checkpointing: type[Module] | list[type[Module]] | None = None,
158153
activation_checkpointing_policy: Optional["_POLICY"] = None,
159154
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
160155
state_dict_type: Literal["full", "sharded"] = "full",
161-
device_mesh: tuple[int] | "DeviceMesh" | None = None,
156+
device_mesh: Union[tuple[int], "DeviceMesh"] | None = None,
162157
**kwargs: Any,
163158
) -> None:
164159
super().__init__(

0 commit comments

Comments
 (0)