|
17 | 17 |
|
18 | 18 | import torch |
19 | 19 | import torchstore as ts |
| 20 | + |
| 21 | +from forge.actors._torchstore_utils import ( |
| 22 | + extract_param_name, |
| 23 | + get_dcp_whole_state_dict_key, |
| 24 | + get_param_key, |
| 25 | + get_param_prefix, |
| 26 | + load_tensor_from_dcp, |
| 27 | + rdma_available, |
| 28 | +) |
| 29 | + |
| 30 | +from forge.controller import ( |
| 31 | + ForgeActor, |
| 32 | + get_proc_mesh, |
| 33 | + host_mesh_from_proc, |
| 34 | + stop_proc_mesh, |
| 35 | +) |
| 36 | +from forge.data_models.completion import Completion |
| 37 | +from forge.data_models.prompt import to_prompt |
| 38 | +from forge.observability.metrics import record_metric, Reduce |
| 39 | +from forge.observability.perf_tracker import Tracer |
| 40 | +from forge.types import ProcessConfig |
| 41 | +from forge.util._shared_tensor import SharedTensor, SharedTensorHandle |
20 | 42 | from monarch.actor import current_rank, endpoint, ProcMesh, this_host |
21 | 43 |
|
22 | 44 | from vllm.config import VllmConfig |
|
42 | 64 | from vllm.v1.structured_output import StructuredOutputManager |
43 | 65 | from vllm.worker.worker_base import WorkerWrapperBase |
44 | 66 |
|
45 | | -from forge.actors._torchstore_utils import ( |
46 | | - extract_param_name, |
47 | | - get_dcp_whole_state_dict_key, |
48 | | - get_param_key, |
49 | | - get_param_prefix, |
50 | | - load_tensor_from_dcp, |
51 | | - rdma_available, |
52 | | -) |
53 | | - |
54 | | -from forge.controller import ( |
55 | | - ForgeActor, |
56 | | - get_proc_mesh, |
57 | | - host_mesh_from_proc, |
58 | | - stop_proc_mesh, |
59 | | -) |
60 | | -from forge.data_models.completion import Completion |
61 | | -from forge.data_models.prompt import to_prompt |
62 | | -from forge.observability.metrics import record_metric, Reduce |
63 | | -from forge.observability.perf_tracker import Tracer |
64 | | -from forge.types import ProcessConfig |
65 | | -from forge.util._shared_tensor import SharedTensor, SharedTensorHandle |
66 | | - |
67 | 67 | logger = logging.getLogger(__name__) |
68 | 68 | logger.setLevel(logging.INFO) |
69 | 69 |
|
|
0 commit comments