Skip to content

Commit 7640d4d

Browse files
committed
[feat] add Serializer for rpc server
1 parent f24294b commit 7640d4d

File tree

4 files changed

+686
-0
lines changed

4 files changed

+686
-0
lines changed

areal/scheduler/rpc/app.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import json
5+
from collections.abc import Iterable, Mapping
6+
from typing import Any
7+
8+
from flask import Flask
9+
from flask_jsonrpc import JSONRPC
10+
11+
from areal.api.cli_args import NameResolveConfig
12+
from areal.utils import logging, name_resolve
13+
from areal.utils.dynamic_import import import_from_string
14+
from areal.utils.seeding import set_random_seed
15+
from areal.utils.stats_tracker import export_all
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def _ensure_jsonable(value: Any) -> Any:
21+
"""Convert Python objects into JSON-serialisable structures."""
22+
if isinstance(value, (str, int, float, bool)) or value is None:
23+
return value
24+
if isinstance(value, Mapping):
25+
return {str(k): _ensure_jsonable(v) for k, v in value.items()}
26+
if isinstance(value, Iterable) and not isinstance(value, (bytes, bytearray)):
27+
return [_ensure_jsonable(v) for v in value]
28+
try:
29+
json.dumps(value)
30+
except TypeError:
31+
return repr(value)
32+
return value
33+
34+
35+
class EngineRPCServer:
36+
engine = None
37+
38+
def create_engine(
39+
self,
40+
payload: dict[str, Any],
41+
) -> dict[str, Any]:
42+
if not isinstance(payload, Mapping):
43+
raise ValueError("payload must be a mapping object")
44+
45+
cls_path = payload.get("engine_class")
46+
if not cls_path:
47+
raise ValueError("engine_class is required")
48+
49+
cls_args = payload.get("engine_args", [])
50+
cls_kwargs = payload.get("engine_kwargs", {})
51+
init_args = payload.get("initialize_args", [])
52+
init_kwargs = payload.get("initialize_kwargs", {})
53+
54+
try:
55+
engine_cls = import_from_string(cls_path)
56+
except (ImportError, AttributeError, ValueError) as exc:
57+
raise ValueError(
58+
f"Failed to import engine class '{cls_path}': {exc}"
59+
) from exc
60+
61+
try:
62+
engine = engine_cls(*cls_args, **cls_kwargs)
63+
except Exception as exc: # pylint: disable=broad-except
64+
raise RuntimeError(f"Failed to instantiate engine: {exc}") from exc
65+
66+
try:
67+
init_result = engine.initialize(*init_args, **init_kwargs)
68+
except Exception as exc: # pylint: disable=broad-except
69+
raise RuntimeError(f"Engine initialise failed: {exc}") from exc
70+
71+
self.engine = engine
72+
return {
73+
"status": "initialized",
74+
"initialize_result": _ensure_jsonable(init_result),
75+
}
76+
77+
def call_engine(self, payload: dict[str, Any]) -> dict[str, Any]:
78+
if not isinstance(payload, Mapping):
79+
raise ValueError("payload must be a mapping object")
80+
81+
method_name = payload.get("method")
82+
if method_name is None:
83+
raise ValueError("method is required")
84+
85+
if self.engine is None:
86+
raise RuntimeError("Engine not found")
87+
88+
engine = self.engine
89+
args = payload.get("args", [])
90+
kwargs = payload.get("kwargs", {})
91+
92+
try:
93+
target = getattr(engine, method_name)
94+
except AttributeError as exc:
95+
raise ValueError(f"Engine has no method '{method_name}'") from exc
96+
97+
try:
98+
result = target(*args, **kwargs)
99+
except Exception as exc: # pylint: disable=broad-except
100+
raise RuntimeError(f"Engine call failed: {exc}") from exc
101+
102+
return {
103+
"method": method_name,
104+
"result": _ensure_jsonable(result),
105+
}
106+
107+
def configure(self, payload: dict[str, Any]) -> dict[str, Any]:
108+
if not isinstance(payload, Mapping):
109+
raise ValueError("payload must be a mapping object")
110+
111+
seed_cfg = payload.get("seed")
112+
if seed_cfg:
113+
base_seed = seed_cfg.get("base_seed")
114+
key = seed_cfg.get("key", "default")
115+
if base_seed is None:
116+
raise ValueError("seed.base_seed is required")
117+
set_random_seed(int(base_seed), str(key))
118+
logger.info("Random seed configured: base=%s key=%s", base_seed, key)
119+
120+
name_resolve_cfg = payload.get("name_resolve")
121+
if name_resolve_cfg:
122+
try:
123+
config = NameResolveConfig(**name_resolve_cfg)
124+
except TypeError as exc:
125+
raise ValueError(f"Invalid name_resolve payload: {exc}") from exc
126+
name_resolve.reconfigure(config)
127+
logger.info("Name resolver reconfigured: type=%s", config.type)
128+
129+
return {"status": "configured"}
130+
131+
def export_stats(self, reset: bool = True) -> dict[str, Any]:
132+
try:
133+
stats = export_all(reset=reset)
134+
except Exception as exc: # pylint: disable=broad-except
135+
logger.exception("Failed to export stats")
136+
raise RuntimeError(f"Failed to export stats: {exc}") from exc
137+
138+
return {"stats": _ensure_jsonable(stats), "reset": reset}
139+
140+
def health(self) -> dict[str, str]:
141+
return {"status": "ok"}
142+
143+
144+
def create_app():
145+
app = Flask(__name__)
146+
engine_rpc_server = EngineRPCServer()
147+
jsonrpc = JSONRPC(app, "/api", enable_web_browsable_api=False)
148+
149+
@jsonrpc.method("App.create_engine")
150+
def rpc_create_engine(payload: dict[str, Any]): # type: ignore[override]
151+
return engine_rpc_server.create_engine(payload)
152+
153+
@jsonrpc.method("App.call_engine")
154+
def rpc_call_engine(payload: dict[str, Any]): # type: ignore[override]
155+
return engine_rpc_server.call_engine(payload)
156+
157+
@jsonrpc.method("App.configure")
158+
def rpc_configure(payload: dict[str, Any]): # type: ignore[override]
159+
return engine_rpc_server.configure(payload)
160+
161+
@jsonrpc.method("App.health")
162+
def rpc_health(): # type: ignore[override]
163+
return engine_rpc_server.health()
164+
165+
@jsonrpc.method("App.export_stats")
166+
def rpc_export_stats(reset: bool = True): # type: ignore[override]
167+
return engine_rpc_server.export_stats(reset)
168+
169+
return app
170+
171+
172+
if __name__ == "__main__":
173+
parser = argparse.ArgumentParser()
174+
parser.add_argument("--host", default="0.0.0.0")
175+
parser.add_argument("--port", default=5000, type=int, required=False)
176+
args = parser.parse_args()
177+
178+
app = create_app()
179+
app.run(host=args.host, port=args.port)

0 commit comments

Comments
 (0)