|
12 | 12 | import torch |
13 | 13 | import torch.nn as nn |
14 | 14 | from torch.distributed import DeviceMesh |
15 | | -from torch.distributed._composable.fsdp import (CPUOffloadPolicy, |
16 | | - MixedPrecisionPolicy, |
17 | | - fully_shard) |
| 15 | +from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard |
18 | 16 | from torch.distributed._composable.replicate import replicate |
19 | 17 | from torch.distributed._tensor import Replicate, Shard |
20 | | -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ |
21 | | - checkpoint_wrapper as ptd_checkpoint_wrapper |
22 | | -from torch.distributed.tensor.parallel import (ColwiseParallel, |
23 | | - PrepareModuleInput, |
24 | | - PrepareModuleOutput, |
25 | | - RowwiseParallel, |
26 | | - SequenceParallel, |
27 | | - parallelize_module) |
| 18 | +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper |
| 19 | +from torch.distributed.tensor.parallel import ( |
| 20 | + ColwiseParallel, |
| 21 | + PrepareModuleInput, |
| 22 | + PrepareModuleOutput, |
| 23 | + RowwiseParallel, |
| 24 | + SequenceParallel, |
| 25 | + parallelize_module |
| 26 | +) |
28 | 27 |
|
29 | 28 | from fla.modules.fused_linear_cross_entropy import LinearLossParallel |
30 | 29 | from fla.modules.mlp import SwiGLULinearParallel |
@@ -126,8 +125,10 @@ def __init__( |
126 | 125 | # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there |
127 | 126 | try: |
128 | 127 | from torchao.float8.float8_tensor_parallel import ( |
129 | | - Float8ColwiseParallel, Float8RowwiseParallel, |
130 | | - PrepareFloat8ModuleInput) |
| 128 | + Float8ColwiseParallel, |
| 129 | + Float8RowwiseParallel, |
| 130 | + PrepareFloat8ModuleInput |
| 131 | + ) |
131 | 132 | except ImportError: |
132 | 133 | Float8ColwiseParallel = None |
133 | 134 | Float8RowwiseParallel = None |
@@ -268,8 +269,7 @@ def apply_tp( |
268 | 269 | ) |
269 | 270 |
|
270 | 271 | if enable_async_tp: |
271 | | - from torch.distributed._symmetric_memory import \ |
272 | | - enable_symm_mem_for_group |
| 272 | + from torch.distributed._symmetric_memory import enable_symm_mem_for_group |
273 | 273 |
|
274 | 274 | torch._inductor.config._micro_pipeline_tp = True |
275 | 275 | enable_symm_mem_for_group(tp_mesh.get_group().group_name) |
@@ -312,8 +312,7 @@ def _apply_ac_to_block(module: nn.Module, ac_config): |
312 | 312 | f"Valid options: 'op' or a positive int representing layer frequency" |
313 | 313 | ) |
314 | 314 | if use_op_sac: |
315 | | - from torch.utils.checkpoint import ( |
316 | | - CheckpointPolicy, create_selective_checkpoint_contexts) |
| 315 | + from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts |
317 | 316 |
|
318 | 317 | def _get_custom_policy(meta): |
319 | 318 | def _custom_policy(ctx, func, *args, **kwargs): |
|
0 commit comments