Commit 20a8acc

add rpc server/client for the single-controller mode (#377)

* add rpc server/client for the single-controller mode
* add type string, remove chinese comment, update doc string
* process distributed batch in rpc server
* add unittest

Co-authored-by: 仲青 <[email protected]>

1 parent b103d7b commit 20a8acc

File tree

4 files changed: +631 −1 lines changed

areal/scheduler/rpc/rpc_client.py

Lines changed: 137 additions & 0 deletions
```python
import gzip
import time
from http import HTTPStatus
from typing import Any, Union

import cloudpickle
import requests

from areal.api.cli_args import InferenceEngineConfig, TrainEngineConfig
from areal.api.engine_api import InferenceEngine, TrainEngine
from areal.utils import logging
from areal.utils.http import response_ok, response_retryable

logger = logging.getLogger("RPCClient")


class RPCClient:
    def __init__(self):
        self._addrs = {}

    def register(self, worker_id: str, ip: str, port: int) -> None:
        self._addrs[worker_id] = (ip, port)
        logger.info(f"Registered worker {worker_id} at {ip}:{port}")

    def create_engine(
        self,
        worker_id: str,
        engine_obj: Union[InferenceEngine, TrainEngine],
        init_config: Union[InferenceEngineConfig, TrainEngineConfig],
    ) -> Any:
        ip, port = self._addrs[worker_id]
        url = f"http://{ip}:{port}/create_engine"
        logger.info(f"Sending create_engine to {worker_id} ({ip}:{port})")
        payload = (engine_obj, init_config)
        serialized_data = cloudpickle.dumps(payload)
        serialized_obj = gzip.compress(serialized_data)
        resp = requests.post(url, data=serialized_obj)
        logger.info(
            f"Sent create_engine to {worker_id} ({ip}:{port}), status={resp.status_code}"
        )
        if resp.status_code == HTTPStatus.OK:
            logger.info("Engine created successfully.")
            return cloudpickle.loads(resp.content)
        else:
            logger.error(f"Failed to create engine, {resp.status_code}, {resp.content}")
            raise RuntimeError(
                f"Failed to create engine, {resp.status_code}, {resp.content}"
            )

    def call_engine(
        self, worker_id: str, method: str, max_retries: int = 3, *args, **kwargs
    ) -> Any:
        """
        Call the RPC server with a method name and arguments, retrying on failure.

        Parameters
        ----------
        worker_id: str
            The id of the worker to call.
        method: str
            The name of the engine method to call.
        max_retries: int
            Maximum number of attempts on retryable failures.
        *args:
            Positional arguments to pass to the method.
        **kwargs:
            Keyword arguments to pass to the method.

        Returns
        -------
        The deserialized result from the RPC server.
        """
        req = (method, args, kwargs)
        serialized_data = cloudpickle.dumps(req)

        return self._call_engine_with_serialized_data(
            worker_id, serialized_data, max_retries
        )

    def _call_engine_with_serialized_data(
        self, worker_id: str, serialized_data: bytes, max_retries=3
    ) -> Any:
        """
        Call the RPC server with pre-serialized data, retrying on failure.

        Parameters
        ----------
        worker_id: str
            The id of the worker to call.
        serialized_data: bytes
            The serialized request to send.
        max_retries: int
            Maximum number of attempts on retryable failures.

        Returns
        -------
        The deserialized result from the RPC server.
        """
        ip, port = self._addrs[worker_id]
        url = f"http://{ip}:{port}/call"
        last_exception = None

        for attempt in range(max_retries):
            try:
                resp = requests.post(url, data=serialized_data, timeout=7200)
                logger.info(
                    f"Sent call to {worker_id} ({ip}:{port}), status={resp.status_code}, attempt {attempt + 1}/{max_retries}"
                )

                if response_ok(resp.status_code):
                    return cloudpickle.loads(resp.content)
                elif response_retryable(resp.status_code):
                    last_exception = RuntimeError(
                        f"Retryable HTTP status {resp.status_code}: {resp.content}"
                    )
                else:
                    raise RuntimeError(
                        f"Non-retryable HTTP error: {resp.status_code} - {resp.content}"
                    )

            except (RuntimeError, TimeoutError) as e:
                logger.error(f"Stop retrying, error on attempt {attempt + 1}: {e}")
                raise e
            except Exception as e:
                last_exception = e
                logger.error(f"Error on attempt {attempt + 1}: {e}")

            if last_exception is not None:
                if attempt < max_retries - 1:
                    logger.warning(
                        f"Retrying in 1 second... ({attempt + 1}/{max_retries})"
                    )
                    time.sleep(1)
                    continue
                else:
                    logger.error(f"Max retries exceeded for {url}")
                    raise last_exception
```
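
For orientation, here is a minimal usage sketch of the client. The worker id, address, engine object, config, method name, and batch are hypothetical placeholders (any concrete `TrainEngine` implementation and its `TrainEngineConfig` would do); only `RPCClient` itself comes from this commit.

```python
# Hypothetical sketch: my_engine, my_config, batch, and "train_batch"
# are placeholders, not part of this commit.
from areal.scheduler.rpc.rpc_client import RPCClient

client = RPCClient()
client.register("worker-0", "10.0.0.5", 8000)

# Ships (engine, config) as a gzip-compressed cloudpickle; the server
# stores the engine and invokes engine.initialize(config) remotely.
client.create_engine("worker-0", my_engine, my_config)

# Note that max_retries precedes *args in the signature, so the retry
# count (3) comes before the positional arguments forwarded to the
# remote method.
result = client.call_engine("worker-0", "train_batch", 3, batch)
```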

areal/scheduler/rpc/rpc_server.py

Lines changed: 149 additions & 0 deletions
```python
import argparse
import gzip
import os
import traceback
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer

import cloudpickle
from tensordict import TensorDict

from areal.api.controller_api import DistributedBatch
from areal.controller.batch import DistributedBatchMemory
from areal.utils import logging

logger = logging.getLogger("RPCServer")


def process_input_to_distributed_batch(*args, **kwargs):
    # Unwrap DistributedBatch arguments into their underlying data
    # before forwarding them to the engine method.
    for i in range(len(args)):
        if isinstance(args[i], DistributedBatch):
            args = list(args)
            args[i] = args[i].get_data()
            args = tuple(args)

    for k in list(kwargs.keys()):
        if isinstance(kwargs[k], DistributedBatch):
            kwargs[k] = kwargs[k].get_data()

    return args, kwargs


def process_output_to_distributed_batch(result):
    # Wrap batch-like results back into DistributedBatchMemory for the reply.
    if isinstance(result, dict):
        return DistributedBatchMemory.from_dict(result)
    elif isinstance(result, TensorDict):
        return DistributedBatchMemory.from_dict(result.to_dict())
    elif isinstance(result, (list, tuple)):
        return DistributedBatchMemory.from_list(list(result))
    else:
        return result


class EngineRPCServer(BaseHTTPRequestHandler):
    engine = None

    def _read_body(self, timeout=120.0) -> bytes:
        old_timeout = None
        try:
            length = int(self.headers["Content-Length"])
            old_timeout = self.request.gettimeout()
            logger.info(f"Received RPC call, path: {self.path}, timeout: {old_timeout}")
            # Cap the read timeout (120s by default) so a hung read raises
            # an exception instead of blocking forever.
            self.request.settimeout(timeout)
            return self.rfile.read(length)
        finally:
            # Restore the socket's original timeout even if the read fails.
            self.request.settimeout(old_timeout)

    def do_POST(self):
        data = None
        try:
            data = self._read_body()
        except Exception as e:
            # 408: the request body could not be read in time.
            self.send_response(HTTPStatus.REQUEST_TIMEOUT)
            self.end_headers()
            self.wfile.write(
                f"Exception: {e}\n{traceback.format_exc()}".encode("utf-8")
            )
            logger.error(f"Exception in do_POST: {e}\n{traceback.format_exc()}")
            return

        try:
            if self.path == "/create_engine":
                decompressed_data = gzip.decompress(data)
                engine_obj, init_args = cloudpickle.loads(decompressed_data)
                EngineRPCServer.engine = engine_obj
                result = EngineRPCServer.engine.initialize(init_args)
                logger.info(f"Engine created and initialized on RPC server: {result}")
                self.send_response(HTTPStatus.OK)
                self.end_headers()
                self.wfile.write(cloudpickle.dumps(result))
            elif self.path == "/call":
                if EngineRPCServer.engine is None:
                    self.send_response(HTTPStatus.INTERNAL_SERVER_ERROR)
                    self.end_headers()
                    self.wfile.write(b"Engine is none")
                    logger.error("Call received but engine is none.")
                    return
                action, args, kwargs = cloudpickle.loads(data)
                method = getattr(EngineRPCServer.engine, action)
                # NOTE: Do not log args here; they may contain very large tensors.
                logger.info(f"RPC server calling engine method: {action}")
                args, kwargs = process_input_to_distributed_batch(*args, **kwargs)
                result = method(*args, **kwargs)
                result = process_output_to_distributed_batch(result)
                self.send_response(HTTPStatus.OK)
                self.end_headers()
                self.wfile.write(cloudpickle.dumps(result))
            else:
                self.send_response(HTTPStatus.NOT_FOUND)
                self.end_headers()
        except Exception as e:
            self.send_response(HTTPStatus.INTERNAL_SERVER_ERROR)
            self.end_headers()
            self.wfile.write(
                f"Exception: {e}\n{traceback.format_exc()}".encode("utf-8")
            )
            logger.error(f"Exception in do_POST: {e}\n{traceback.format_exc()}")


def start_rpc_server(port):
    server = ThreadingHTTPServer(("0.0.0.0", port), EngineRPCServer)
    server.serve_forever()


def get_serve_port(args):
    port = args.port
    port_str = os.environ.get("PORT_LIST", "").strip()

    # PORT_LIST, if set, takes precedence over the --port argument.
    if port_str:
        # Split by comma and strip whitespace.
        ports = [p.strip() for p in port_str.split(",")]
        # Use the first valid port from the list.
        if ports and ports[0]:
            try:
                return int(ports[0])
            except ValueError:
                logger.warning(
                    f"Invalid port '{ports[0]}' in PORT_LIST. Falling back to --port argument."
                )
    return port


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, required=False)
    args, unknown = parser.parse_known_args()
    port = get_serve_port(args)
    logger.info(f"About to start RPC server on {port}")
    start_rpc_server(port)
```
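
To make the wire format concrete, here is a sketch of the two endpoints as raw HTTP requests, assuming a server already running locally. The port and the `my_engine`, `my_config`, and `batch` values are placeholders; this mirrors what `RPCClient` does internally rather than adding any new protocol.

```python
# Illustrative sketch: the port and the my_engine/my_config/batch values
# are placeholders. A worker would typically be launched with
#   python areal/scheduler/rpc/rpc_server.py --port 8000
# or with PORT_LIST="8000,..." set, which takes precedence over --port.
import gzip

import cloudpickle
import requests

# /create_engine expects a gzip-compressed cloudpickle of (engine_obj, init_config).
body = gzip.compress(cloudpickle.dumps((my_engine, my_config)))
requests.post("http://127.0.0.1:8000/create_engine", data=body)

# /call expects an uncompressed cloudpickle of (method_name, args, kwargs);
# the response body is the cloudpickled result, with batch-like outputs
# wrapped into DistributedBatchMemory by the server before pickling.
body = cloudpickle.dumps(("train_batch", (batch,), {}))
resp = requests.post("http://127.0.0.1:8000/call", data=body)
result = cloudpickle.loads(resp.content)
```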
