|
16 | 16 |
|
17 | 17 | import torch |
18 | 18 | import torchstore as ts |
19 | | - |
20 | | -from forge.actors._torchstore_utils import ( |
21 | | - extract_param_name, |
22 | | - get_dcp_whole_state_dict_key, |
23 | | - get_param_key, |
24 | | - get_param_prefix, |
25 | | - load_tensor_from_dcp, |
26 | | -) |
27 | | - |
28 | | -from forge.controller import ( |
29 | | - ForgeActor, |
30 | | - get_proc_mesh, |
31 | | - host_mesh_from_proc, |
32 | | - stop_proc_mesh, |
33 | | -) |
34 | | -from forge.data_models.completion import Completion |
35 | | -from forge.data_models.prompt import to_prompt |
36 | | -from forge.env import TORCHSTORE_USE_RDMA |
37 | | -from forge.observability.metrics import record_metric, Reduce |
38 | | -from forge.observability.perf_tracker import Tracer |
39 | | -from forge.types import ProcessConfig |
40 | 19 | from monarch.actor import current_rank, endpoint, ProcMesh |
41 | 20 | from vllm.config import VllmConfig |
42 | 21 |
|
|
61 | 40 | from vllm.v1.structured_output import StructuredOutputManager |
62 | 41 | from vllm.worker.worker_base import WorkerWrapperBase |
63 | 42 |
|
| 43 | +from forge.actors._torchstore_utils import ( |
| 44 | + extract_param_name, |
| 45 | + get_dcp_whole_state_dict_key, |
| 46 | + get_param_key, |
| 47 | + get_param_prefix, |
| 48 | + load_tensor_from_dcp, |
| 49 | +) |
| 50 | + |
| 51 | +from forge.controller import ( |
| 52 | + ForgeActor, |
| 53 | + get_proc_mesh, |
| 54 | + host_mesh_from_proc, |
| 55 | + stop_proc_mesh, |
| 56 | +) |
| 57 | +from forge.data_models.completion import Completion |
| 58 | +from forge.data_models.prompt import to_prompt |
| 59 | +from forge.env import TORCHSTORE_USE_RDMA |
| 60 | +from forge.observability.metrics import record_metric, Reduce |
| 61 | +from forge.observability.perf_tracker import Tracer |
| 62 | +from forge.types import ProcessConfig |
| 63 | + |
64 | 64 | logger = logging.getLogger(__name__) |
65 | 65 | logger.setLevel(logging.INFO) |
66 | 66 |
|
@@ -160,7 +160,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] |
160 | 160 | generator_proc = await get_proc_mesh( |
161 | 161 | process_config=generator_proc_config, host_mesh=host_mesh |
162 | 162 | ) |
163 | | - |
164 | 163 | # TODO - expand support so name can stick within kwargs |
165 | 164 | actor_name = kwargs.pop("name", cls.__name__) |
166 | 165 | generator = generator_proc.spawn( |
|
0 commit comments