|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +from collections.abc import Callable |
14 | 15 | from contextlib import AbstractContextManager |
15 | | -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union |
| 16 | +from typing import TYPE_CHECKING, Any, Literal, Optional |
16 | 17 |
|
17 | 18 | import torch |
18 | 19 | from torch import Tensor |
@@ -188,7 +189,7 @@ def setup_module(self, module: Module) -> Module: |
188 | 189 | def module_to_device(self, module: Module) -> None: |
189 | 190 | return self.xla_fsdp_impl.module_to_device(module=module) |
190 | 191 |
|
191 | | - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: |
| 192 | + def module_init_context(self, empty_init: bool | None = None) -> AbstractContextManager: |
192 | 193 | return self.xla_fsdp_impl.module_init_context(empty_init=empty_init) |
193 | 194 |
|
194 | 195 | @override |
@@ -251,7 +252,7 @@ def all_reduce( |
251 | 252 | return self.xla_fsdp_impl.all_reduce(output=output, group=group, reduce_op=reduce_op) |
252 | 253 |
|
253 | 254 | @override |
254 | | - def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: |
| 255 | + def barrier(self, name: str | None = None, *args: Any, **kwargs: Any) -> None: |
255 | 256 | return self.xla_fsdp_impl.barrier(name=name, *args, **kwargs) |
256 | 257 |
|
257 | 258 | @override |
|
0 commit comments