Skip to content

Commit 8a98f89

Browse files
committed
refactor server
1 parent edb83cf commit 8a98f89

File tree

8 files changed

+1056
-137
lines changed

8 files changed

+1056
-137
lines changed

areal/scheduler/rpc/api.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import typing as t
2+
from enum import Enum
3+
from typing import Any
4+
5+
import flask_jsonrpc.types.params as tp
6+
from pydantic import BaseModel
7+
from typing_extensions import Self
8+
9+
from areal.api.cli_args import (
10+
InferenceEngineConfig,
11+
NameResolveConfig,
12+
TrainEngineConfig,
13+
)
14+
15+
16+
class BaseException(Exception):
    """Root exception type declared by the scheduler RPC API module.

    ``InvalidParamsException`` in this module derives from it.

    NOTE(review): the class name shadows Python's builtin ``BaseException``
    in every namespace that imports it. The name is kept as-is because it is
    part of this module's public interface.
    """

    def __init__(
        self: Self,
        message: t.Annotated[str, tp.Summary("Exception reason")],
    ) -> None:
        """Initialise the exception with a single message argument.

        Args:
            message: Reason for the failure; forwarded to ``Exception``.
        """
        super().__init__(message)
21+
22+
23+
class InvalidParamsException(BaseException):
    """Exception whose message reports that invalid params were received."""

    def __init__(self: Self, params: t.Annotated[str, tp.Summary("")]) -> None:
        """Build the exception message from ``params``.

        Args:
            params: Text describing the offending parameters; it is
                interpolated into the exception message.
        """
        reason = f"Invalid Params Received: {params}"
        # The ``message=`` keyword works because ``BaseException`` here is the
        # class defined in this module, not the builtin (which rejects
        # keyword arguments).
        super().__init__(message=reason)
26+
27+
28+
class EngineNameEnum(str, Enum):
    """Closed set of engine identifiers, used for ``CreateEnginePayload.class_name``.

    Members also subclass ``str``, so each one compares equal to its raw
    string value.
    """

    # NOTE(review): judging by the names, the first two are training backends
    # and the last two are remote inference backends -- confirm with the server.
    FSDP = "fsdp"
    MEGATRON = "megatron"
    SGLANG_REMOTE = "sglang_remote"
    VLLM_REMOTE = "vllm_remote"
33+
34+
35+
class ConfigurePayload(BaseModel):
    """Request body for the ``areal.configure`` RPC method."""

    # Seed-related settings as a free-form dict; its schema is not defined
    # here (NOTE(review): confirm expected keys against the server handler).
    seed_cfg: dict
    # name_resolve configuration to apply on the server side.
    name_resolve: NameResolveConfig
38+
39+
40+
class CreateEnginePayload(BaseModel):
    """Request body for the ``areal.create_engine`` RPC method."""

    # Engine configuration; either a train or an inference engine config.
    config: TrainEngineConfig | InferenceEngineConfig
    # Which engine kind to create, as a member of ``EngineNameEnum``.
    class_name: EngineNameEnum
    # Extra named arguments for engine creation (NOTE(review): presumably
    # forwarded to the engine's initialisation -- confirm against the server).
    initial_args: dict[str, Any]
44+
45+
46+
class CallEnginePayload(BaseModel):
    """Request body for the ``areal.call_engine`` RPC method."""

    # Name of the method to call on the remote engine instance.
    method: str
    # Positional arguments for that method.
    args: list[Any]
    # Keyword arguments for that method.
    kwargs: dict[str, Any]
50+
51+
52+
class Response(BaseModel):
    """Standard reply envelope returned by the RPC methods."""

    # Whether the call succeeded; the client raises when this is False.
    success: bool
    # Status or error text accompanying the result.
    message: str
    # Optional result value; defaults to None when the reply carries no data.
    data: Any | None = None

areal/scheduler/rpc/client.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import json
2+
from typing import Any
3+
4+
import requests
5+
6+
from areal.api.cli_args import (
7+
InferenceEngineConfig,
8+
NameResolveConfig,
9+
TrainEngineConfig,
10+
)
11+
from areal.scheduler.rpc.api import (
12+
CallEnginePayload,
13+
ConfigurePayload,
14+
CreateEnginePayload,
15+
EngineNameEnum,
16+
Response,
17+
)
18+
from areal.utils import logging
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class EngineRPCClient:
    """JSON-RPC 2.0 client for the engine RPC server.

    Every call is sent as an HTTP POST to ``http://{host}:{port}/api`` and the
    ``result`` member of the reply is parsed into a ``Response``.

    Args:
        host: Host name or address of the RPC server.
        port: TCP port of the RPC server.
        timeout: Seconds to wait for each HTTP request, passed through to
            ``requests.post``; ``None`` disables the timeout. Defaults to
            300, the value that used to be hard-coded.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5000,
        timeout: float | None = 300,
    ):
        self.host = host
        self.port = port
        self.timeout = timeout
        self.base_url = f"http://{host}:{port}/api"
        # JSON-RPC request id; incremented before each request is sent.
        self._request_id = 0

    def _send_request(self, method: str, params: dict[str, Any]) -> Response:
        """Send a JSON-RPC request and parse the standard Response envelope.

        Args:
            method: Fully qualified JSON-RPC method name.
            params: JSON-serialisable parameters for the call.

        Returns:
            The ``Response`` built from the ``result`` member of the reply.

        Raises:
            RuntimeError: If the reply contains a JSON-RPC ``error`` member.
            Exception: Any transport, HTTP-status or decoding error is logged
                and re-raised unchanged.
        """

        self._request_id += 1
        payload = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
            "id": self._request_id,
        }
        headers = {"Content-Type": "application/json"}
        try:
            response = requests.post(
                self.base_url,
                data=json.dumps(payload),
                headers=headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            result = response.json()
            if "error" in result:
                raise RuntimeError(f"JSON-RPC error: {result['error']}")
            response_data = result["result"]
            return Response(
                success=response_data["success"],
                message=response_data["message"],
                data=response_data.get("data"),
            )
        except Exception as exc:  # pragma: no cover - network failure path
            logger.error(f"Request failed: {exc}")
            raise

    def _send_checked(
        self, method: str, params: dict[str, Any], failure: str
    ) -> Response:
        """Send a request and raise ``RuntimeError`` on an unsuccessful reply.

        Args:
            method: Fully qualified JSON-RPC method name.
            params: JSON-serialisable parameters for the call.
            failure: Prefix of the error message used when the server
                reports ``success=False``.

        Returns:
            The successful ``Response``.

        Raises:
            RuntimeError: If the server reports ``success=False``.
        """

        response = self._send_request(method, params)
        if not response.success:
            raise RuntimeError(f"{failure}: {response.message}")
        return response

    def create_engine(
        self,
        config: TrainEngineConfig | InferenceEngineConfig,
        class_name: EngineNameEnum,
        initial_args: dict[str, Any],
    ) -> None:
        """Create a remote engine instance.

        This mirrors the payload structure used in test_rpc_integration and
        areal.scheduler.rpc.server.create_app.

        Args:
            config: Train or inference engine configuration.
            class_name: Engine kind to create.
            initial_args: Extra named arguments sent with the request.

        Raises:
            RuntimeError: If the server reports failure.
        """

        payload = CreateEnginePayload(
            config=config,
            class_name=class_name,
            initial_args=initial_args,
        )
        self._send_checked(
            "areal.create_engine",
            {"payload": payload.model_dump()},
            "Failed to create engine",
        )

    def call_engine(self, method: str, *args, **kwargs) -> Any:
        """Call a method on the remote engine instance.

        Args:
            method: Name of the engine method to call.
            *args: Positional arguments for that method.
            **kwargs: Keyword arguments for that method.

        Returns:
            The ``data`` member of the server's response.

        Raises:
            RuntimeError: If the server reports failure.
        """

        payload = CallEnginePayload(
            method=method,
            args=list(args),
            kwargs=kwargs,
        )
        response = self._send_checked(
            "areal.call_engine",
            {"payload": payload.model_dump()},
            "Failed to call engine",
        )
        return response.data

    def configure(
        self,
        seed_cfg: dict[str, Any] | None = None,
        name_resolve: NameResolveConfig | None = None,
    ) -> None:
        """Configure global settings such as random seed and name_resolve.

        Args:
            seed_cfg: Seed settings; an empty dict is sent when omitted.
            name_resolve: name_resolve configuration; a default-constructed
                ``NameResolveConfig`` is sent when omitted.

        Raises:
            RuntimeError: If the server reports failure.
        """

        payload = ConfigurePayload(
            seed_cfg=seed_cfg or {},
            name_resolve=name_resolve or NameResolveConfig(),
        )
        self._send_checked(
            "areal.configure",
            {"payload": payload.model_dump()},
            "Failed to configure",
        )

    def health(self) -> Response:
        """Health check for the remote engine server.

        Returns:
            The raw ``Response``; ``success`` is not checked here.
        """

        return self._send_request("areal.health", {})

    def export_stats(self, reset: bool = True) -> Any:
        """Export statistics from the remote engine server.

        Args:
            reset: Flag forwarded to the server as the ``reset`` parameter
                (NOTE(review): presumably clears the stats after export --
                confirm against the server).

        Returns:
            The ``data`` member of the server's response.

        Raises:
            RuntimeError: If the server reports failure.
        """

        response = self._send_checked(
            "areal.export_stats",
            {"reset": reset},
            "Failed to export stats",
        )
        return response.data

areal/scheduler/rpc/rpc_client.py

Lines changed: 0 additions & 137 deletions
This file was deleted.

0 commit comments

Comments
 (0)