Skip to content

Commit c96400d

Browse files
committed
merge
2 parents 7f35b26 + 839664c commit c96400d

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

src/forge/actors/generator.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,6 @@
1616

1717
import torch
1818
import torchstore as ts
19-
20-
from forge.actors._torchstore_utils import (
21-
extract_param_name,
22-
get_dcp_whole_state_dict_key,
23-
get_param_key,
24-
get_param_prefix,
25-
load_tensor_from_dcp,
26-
)
27-
28-
from forge.controller import (
29-
ForgeActor,
30-
get_proc_mesh,
31-
host_mesh_from_proc,
32-
stop_proc_mesh,
33-
)
34-
from forge.data_models.completion import Completion
35-
from forge.data_models.prompt import to_prompt
36-
from forge.env import TORCHSTORE_USE_RDMA
37-
from forge.observability.metrics import record_metric, Reduce
38-
from forge.observability.perf_tracker import Tracer
39-
from forge.types import ProcessConfig
4019
from monarch.actor import current_rank, endpoint, ProcMesh
4120
from vllm.config import VllmConfig
4221

@@ -61,6 +40,27 @@
6140
from vllm.v1.structured_output import StructuredOutputManager
6241
from vllm.worker.worker_base import WorkerWrapperBase
6342

43+
from forge.actors._torchstore_utils import (
44+
extract_param_name,
45+
get_dcp_whole_state_dict_key,
46+
get_param_key,
47+
get_param_prefix,
48+
load_tensor_from_dcp,
49+
)
50+
51+
from forge.controller import (
52+
ForgeActor,
53+
get_proc_mesh,
54+
host_mesh_from_proc,
55+
stop_proc_mesh,
56+
)
57+
from forge.data_models.completion import Completion
58+
from forge.data_models.prompt import to_prompt
59+
from forge.env import TORCHSTORE_USE_RDMA
60+
from forge.observability.metrics import record_metric, Reduce
61+
from forge.observability.perf_tracker import Tracer
62+
from forge.types import ProcessConfig
63+
6464
logger = logging.getLogger(__name__)
6565
logger.setLevel(logging.INFO)
6666

@@ -160,7 +160,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
160160
generator_proc = await get_proc_mesh(
161161
process_config=generator_proc_config, host_mesh=host_mesh
162162
)
163-
164163
# TODO - expand support so name can stick within kwargs
165164
actor_name = kwargs.pop("name", cls.__name__)
166165
generator = generator_proc.spawn(

0 commit comments

Comments
 (0)