
Commit 4dc6ffc

Fix isort indentation
1 parent 8108232 commit 4dc6ffc

5 files changed: 28 additions, 35 deletions


3rdparty/flash-linear-attention

flame/models/parallelize_fla.py

Lines changed: 16 additions & 17 deletions
@@ -12,19 +12,18 @@
 import torch
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed._composable.fsdp import (CPUOffloadPolicy,
-                                                MixedPrecisionPolicy,
-                                                fully_shard)
+from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
 from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
-    checkpoint_wrapper as ptd_checkpoint_wrapper
-from torch.distributed.tensor.parallel import (ColwiseParallel,
-                                               PrepareModuleInput,
-                                               PrepareModuleOutput,
-                                               RowwiseParallel,
-                                               SequenceParallel,
-                                               parallelize_module)
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    PrepareModuleOutput,
+    RowwiseParallel,
+    SequenceParallel,
+    parallelize_module
+)
 
 from fla.modules.fused_linear_cross_entropy import LinearLossParallel
 from fla.modules.mlp import SwiGLULinearParallel

@@ -126,8 +125,10 @@ def __init__(
         # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
         try:
             from torchao.float8.float8_tensor_parallel import (
-                Float8ColwiseParallel, Float8RowwiseParallel,
-                PrepareFloat8ModuleInput)
+                Float8ColwiseParallel,
+                Float8RowwiseParallel,
+                PrepareFloat8ModuleInput
+            )
         except ImportError:
             Float8ColwiseParallel = None
             Float8RowwiseParallel = None

@@ -268,8 +269,7 @@ def apply_tp(
     )
 
     if enable_async_tp:
-        from torch.distributed._symmetric_memory import \
-            enable_symm_mem_for_group
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)

@@ -312,8 +312,7 @@ def _apply_ac_to_block(module: nn.Module, ac_config):
             f"Valid options: 'op' or a positive int representing layer frequency"
         )
     if use_op_sac:
-        from torch.utils.checkpoint import (
-            CheckpointPolicy, create_selective_checkpoint_contexts)
+        from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
 
         def _get_custom_policy(meta):
             def _custom_policy(ctx, func, *args, **kwargs):

flame/models/pipeline_fla.py

Lines changed: 3 additions & 8 deletions
@@ -13,18 +13,13 @@
 import torch.nn as nn
 from torch.distributed import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
-from torch.distributed.pipelining.schedules import (ScheduleZBVZeroBubble,
-                                                    _PipelineSchedule,
-                                                    get_schedule_class)
+from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
 from transformers import PretrainedConfig
 
-from flame.models.parallelize_fla import (get_blocks, get_components_name,
-                                          get_model)
+from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
 from torchtitan.config_manager import JobConfig
 from torchtitan.distributed.parallel_dims import ParallelDims
-from torchtitan.distributed.pipeline import (build_pipeline_schedule,
-                                             generate_split_points,
-                                             stage_ids_this_rank)
+from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
 from torchtitan.tools.logging import logger
 
 DeviceType = Union[int, str, torch.device]

pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -37,3 +37,7 @@ Homepage = "https://github.com/fla-org/flame"
 [build-system]
 requires = ["setuptools>=45", "wheel", "ninja", "torch"]
 build-backend = "setuptools.build_meta"
+
+[tool.isort]
+line_length = 127
+multi_line_output = 3
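
These two settings match the import style applied throughout this commit: line_length = 127 lets isort keep an import on a single line of up to 127 characters, and multi_line_output = 3 selects isort's vertical-hanging-indent mode for imports that still exceed that limit. As a rough illustration only (the module and names below are hypothetical, not from this repository), isort wraps an over-long import like this:

# Hypothetical example of the configured isort behavior; not code from this repository.

# Imports that fit within line_length stay on one line:
from some.hypothetical.package import first_name, second_name

# Imports that would exceed line_length are wrapped with multi_line_output = 3
# (vertical hanging indent): one name per line, closing paren on its own line.
from some.hypothetical.package import (
    first_name,
    second_name,
    third_name
)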

train.py

Lines changed: 4 additions & 9 deletions
@@ -26,20 +26,15 @@
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.ft import FTParallelDims, init_ft_manager
 from torchtitan.components.loss import cross_entropy_loss
-from torchtitan.components.metrics import (_build_metric_logger,
-                                           build_device_memory_monitor,
-                                           ensure_pp_loss_visible)
-from torchtitan.components.optimizer import (build_lr_schedulers,
-                                             build_optimizers)
+from torchtitan.components.metrics import _build_metric_logger, build_device_memory_monitor, ensure_pp_loss_visible
+from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed import utils as dist_utils
 from torchtitan.protocols.model_converter import build_model_converters
-from torchtitan.protocols.train_spec import (TrainSpec, get_train_spec,
-                                             register_train_spec)
+from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
-from torchtitan.tools.profiling import (maybe_enable_memory_snapshot,
-                                        maybe_enable_profiling)
+from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
 
 register_train_spec(
     TrainSpec(
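
Usage note: with the [tool.isort] table in pyproject.toml, running isort . from the repository root reformats imports to the style shown in the diffs above, and isort --check-only . verifies them without modifying files; both commands read the configuration automatically. This is standard isort behavior rather than anything defined by this commit.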
