Skip to content

Commit 7a56365

Browse files
committed
run lint
1 parent 85a9aed commit 7a56365

File tree

11 files changed

+24
-23
lines changed

11 files changed

+24
-23
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121
from forge.data.rewards import MathReward, ThinkingReward
2222
from forge.services.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
23-
from forge.services.reference_actor import compute_sequence_logprobs, TitanRefModel
23+
from forge.services.reference_service import compute_sequence_logprobs, TitanRefModel
2424
from forge.services.replay_buffer import ReplayBuffer
2525
from forge.util.metric_logging import get_metric_logger
2626
from monarch.actor import endpoint

src/forge/controller/service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
import math
1010
import sys
1111

12+
from monarch.actor import Actor, current_rank, current_size, endpoint
13+
1214
from forge.controller.proc_mesh import get_proc_mesh, stop_proc_mesh
1315
from forge.types import ProcessConfig
1416

15-
from monarch.actor import Actor, current_rank, current_size, endpoint
16-
1717
logger = logging.getLogger(__name__)
1818
logger.setLevel(logging.DEBUG)
1919

src/forge/controller/service/replica.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
from enum import Enum
1414
from typing import Optional
1515

16+
from monarch.actor import ActorError
17+
1618
from forge.controller import Service
1719
from forge.types import ProcessConfig
1820

19-
from monarch.actor import ActorError
20-
2121
logger = logging.getLogger(__name__)
2222
logger.setLevel(logging.DEBUG)
2323

src/forge/controller/service/spawn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010

1111
from forge.controller import Service
1212

13+
from monarch.actor import proc_mesh
14+
1315
from forge.controller.service import Service, ServiceActor, ServiceConfig
1416

1517
from forge.controller.service.interface import ServiceInterface, ServiceInterfaceV2
16-
from monarch.actor import proc_mesh
1718

1819
logger = logging.getLogger(__name__)
1920
logger.setLevel(logging.INFO)

src/forge/interfaces.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from abc import ABC, abstractmethod
88
from typing import Any, Mapping
99

10+
from monarch.actor import endpoint
11+
1012
from forge.controller import Service
1113

1214
from forge.types import Action, Message, Observation, Scalar, State
1315

14-
from monarch.actor import endpoint
15-
1616

1717
class Transform(ABC):
1818
"""Abstract base class for observation transforms.

src/forge/services/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __getattr__(name):
2525

2626
return ReplayBuffer
2727
elif name == "TitanRefModel":
28-
from .reference_actor import TitanRefModel
28+
from .reference_service import TitanRefModel
2929

3030
return TitanRefModel
3131
else:

src/forge/services/collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212

1313
from typing import Callable
1414

15+
from monarch.actor import Actor, endpoint
16+
1517
from forge.interfaces import Policy
1618

1719
from forge.services.replay_buffer import ReplayBuffer
1820

1921
from forge.types import Trajectory
2022

21-
from monarch.actor import Actor, endpoint
22-
2323

2424
class Collector(Actor):
2525
"""Collects trajectories for the training loop."""

src/forge/services/policy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@
1313
from typing import Dict, List
1414

1515
import torch
16-
17-
from forge.controller import get_proc_mesh, Service, stop_proc_mesh
18-
19-
from forge.data.sharding import VLLMSharding
20-
from forge.interfaces import Policy as PolicyInterface
21-
from forge.types import ProcessConfig
2216
from monarch.actor import current_rank, endpoint, ProcMesh
2317
from torchstore import MultiProcessStore
2418
from torchstore._state_dict_utils import DELIM
@@ -43,6 +37,12 @@
4337
from vllm.v1.structured_output import StructuredOutputManager
4438
from vllm.worker.worker_base import WorkerWrapperBase
4539

40+
from forge.controller import get_proc_mesh, Service, stop_proc_mesh
41+
42+
from forge.data.sharding import VLLMSharding
43+
from forge.interfaces import Policy as PolicyInterface
44+
from forge.types import ProcessConfig
45+
4646

4747
logger = logging.getLogger(__name__)
4848

src/forge/services/reference_actor.py renamed to src/forge/services/reference_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
@dataclass
3939
class TitanRefModel(Service):
4040
"""
41-
Represents a reference actor leveraging a torchtitan model for execution
41+
Represents a reference service leveraging a torchtitan model for execution
4242
4343
Intended for generating reference_logprobs - for example in KL Divergence
4444
"""
@@ -155,7 +155,7 @@ def compute_logprobs(
155155
# Maintained to keep Old GRPO app prior to full migration off of HF
156156
class HuggingFaceRefModel(Service):
157157
"""
158-
Represents a reference actor leveraging HuggingFace for execution
158+
Represents a reference service leveraging HuggingFace for execution
159159
"""
160160

161161
def __init__(self, model_name, device: torch.device | None = None):

src/forge/services/replay_buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from dataclasses import dataclass
99
from typing import Any
1010

11-
from forge.controller import Service
12-
1311
from monarch.actor import endpoint
1412

13+
from forge.controller import Service
14+
1515

1616
@dataclass
1717
class ReplayBuffer(Service):

0 commit comments

Comments
 (0)