Skip to content

Commit 13f918c

Browse files
committed
Run formatter
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 7959a62 commit 13f918c

File tree

27 files changed

+91
-108
lines changed

27 files changed

+91
-108
lines changed

.meta/mast/hydrate_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
python .meta/mast/hydrate_cache.py --model-id Qwen/Qwen3-32B
1515
1616
"""
17+
1718
import argparse
1819
import os
1920
import sys

apps/grpo/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def simple_grpo_loss(
140140

141141
@dataclass
142142
class RewardActor(ForgeActor):
143-
144143
reward_functions: list[Callable]
145144

146145
@endpoint

apps/sft/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,6 @@ def __repr__(self) -> str:
287287

288288

289289
async def run(cfg: DictConfig) -> None:
290-
291290
logging.info("Spawning recipe...")
292291
process_cfg = cfg.pop("processes")
293292

src/forge/actors/coder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
import tempfile
1111
from pathlib import Path
1212

13-
from monarch.actor import endpoint
14-
1513
from forge.controller import ForgeActor
1614

15+
from monarch.actor import endpoint
16+
1717
logger = logging.getLogger(__name__)
1818
logger.setLevel(logging.DEBUG)
1919

src/forge/actors/generator.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,28 @@
1717

1818
import torch
1919
import torchstore as ts
20+
21+
from forge.actors._torchstore_utils import (
22+
extract_param_name,
23+
get_dcp_whole_state_dict_key,
24+
get_param_key,
25+
get_param_prefix,
26+
load_tensor_from_dcp,
27+
rdma_available,
28+
)
29+
30+
from forge.controller import (
31+
ForgeActor,
32+
get_proc_mesh,
33+
host_mesh_from_proc,
34+
stop_proc_mesh,
35+
)
36+
from forge.data_models.completion import Completion
37+
from forge.data_models.prompt import to_prompt
38+
from forge.observability.metrics import record_metric, Reduce
39+
from forge.observability.perf_tracker import Tracer
40+
from forge.types import ProcessConfig
41+
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle
2042
from monarch.actor import current_rank, endpoint, ProcMesh, this_host
2143

2244
from vllm.config import VllmConfig
@@ -42,28 +64,6 @@
4264
from vllm.v1.structured_output import StructuredOutputManager
4365
from vllm.worker.worker_base import WorkerWrapperBase
4466

45-
from forge.actors._torchstore_utils import (
46-
extract_param_name,
47-
get_dcp_whole_state_dict_key,
48-
get_param_key,
49-
get_param_prefix,
50-
load_tensor_from_dcp,
51-
rdma_available,
52-
)
53-
54-
from forge.controller import (
55-
ForgeActor,
56-
get_proc_mesh,
57-
host_mesh_from_proc,
58-
stop_proc_mesh,
59-
)
60-
from forge.data_models.completion import Completion
61-
from forge.data_models.prompt import to_prompt
62-
from forge.observability.metrics import record_metric, Reduce
63-
from forge.observability.perf_tracker import Tracer
64-
from forge.types import ProcessConfig
65-
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle
66-
6767
logger = logging.getLogger(__name__)
6868
logger.setLevel(logging.INFO)
6969

src/forge/actors/reference_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
from dataclasses import dataclass, field, fields
1414

1515
import torch
16+
17+
from forge.controller import ForgeActor
18+
from forge.observability.metrics import record_metric, Reduce
19+
from forge.observability.perf_tracker import Tracer
20+
from forge.util.ops import compute_logprobs
1621
from monarch.actor import current_rank, current_size, endpoint
1722
from torch.distributed.tensor import DTensor
1823

@@ -27,11 +32,6 @@
2732
from torchtitan.experiments.forge.engine import ForgeEngine
2833
from torchtitan.experiments.forge.job_config import ForgeJobConfig
2934

30-
from forge.controller import ForgeActor
31-
from forge.observability.metrics import record_metric, Reduce
32-
from forge.observability.perf_tracker import Tracer
33-
from forge.util.ops import compute_logprobs
34-
3535
logger = logging.getLogger(__name__)
3636
logger.setLevel(logging.INFO)
3737

src/forge/actors/replay_buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
from operator import itemgetter
1212
from typing import Any, Callable
1313

14-
from monarch.actor import endpoint
15-
1614
from forge.controller import ForgeActor
1715
from forge.observability.metrics import record_metric, Reduce
1816
from forge.observability.perf_tracker import trace
1917

18+
from monarch.actor import endpoint
19+
2020
logger = logging.getLogger(__name__)
2121
logger.setLevel(logging.INFO)
2222

src/forge/actors/trainer/titan.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@
1616
import torch.distributed.checkpoint as dcp
1717
import torchstore as ts
1818

19+
from forge.actors._torchstore_utils import (
20+
DcpHandle,
21+
get_dcp_whole_state_dict_key,
22+
get_param_key,
23+
rdma_available,
24+
)
25+
26+
from forge.controller import ForgeActor
27+
from forge.data.utils import batch_to_device
28+
from forge.observability.metrics import record_metric, Reduce
29+
from forge.observability.perf_tracker import Tracer
30+
1931
from monarch.actor import endpoint
2032
from torch import Tensor
2133
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
@@ -36,18 +48,6 @@
3648
from torchtitan.experiments.forge.engine import ForgeEngine
3749
from torchtitan.experiments.forge.job_config import ForgeJobConfig
3850

39-
from forge.actors._torchstore_utils import (
40-
DcpHandle,
41-
get_dcp_whole_state_dict_key,
42-
get_param_key,
43-
rdma_available,
44-
)
45-
46-
from forge.controller import ForgeActor
47-
from forge.data.utils import batch_to_device
48-
from forge.observability.metrics import record_metric, Reduce
49-
from forge.observability.perf_tracker import Tracer
50-
5151
logger = logging.getLogger(__name__)
5252
logger.setLevel(logging.DEBUG)
5353

src/forge/controller/launcher.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import monarch
1818
import torchx.specs as specs
19+
20+
from forge.types import Launcher, LauncherConfig
1921
from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints
2022
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
2123
from monarch._rust_bindings.monarch_hyperactor.config import configure
@@ -25,8 +27,6 @@
2527
from monarch.tools.commands import create, info
2628
from monarch.tools.config import Config, Workspace
2729

28-
from forge.types import Launcher, LauncherConfig
29-
3030
_MAST_AVAILABLE = False
3131

3232
try:
@@ -269,7 +269,6 @@ def add_additional_packages(self, packages: "Packages") -> "Packages":
269269
return packages
270270

271271
def build_appdef(self) -> specs.AppDef:
272-
273272
# create the app definition for the worker
274273
remote_end_python_path = ":".join(
275274
[

src/forge/controller/provisioner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
"""Remote and local resource manager for allocation and provisioning."""
8+
89
import asyncio
910
import logging
1011

@@ -14,6 +15,10 @@
1415

1516
import torch
1617

18+
from forge.controller.launcher import BaseLauncher, get_launcher
19+
from forge.env import all_env_vars, FORGE_DISABLE_METRICS
20+
from forge.types import ProcessConfig, ProvisionerConfig
21+
1722
from monarch._src.actor.actor_mesh import ActorMesh
1823
from monarch._src.actor.shape import Extent
1924

@@ -22,10 +27,6 @@
2227
from monarch.tools import commands
2328
from monarch.utils import setup_env_for_distributed
2429

25-
from forge.controller.launcher import BaseLauncher, get_launcher
26-
from forge.env import all_env_vars, FORGE_DISABLE_METRICS
27-
from forge.types import ProcessConfig, ProvisionerConfig
28-
2930
logger = logging.getLogger(__name__)
3031
logger.setLevel(logging.DEBUG)
3132

@@ -586,7 +587,6 @@ async def shutdown_metric_logger():
586587

587588

588589
async def shutdown():
589-
590590
await shutdown_metric_logger()
591591

592592
logger.info("Shutting down provisioner..")

0 commit comments

Comments
 (0)