Commit 683dfb0

[RPC] Report RPC Session Timeout to Client Instead of "kShutdown" (#15187)
When using the RPC server on an NPU board, a compiled model will sometimes hang the NPU because of buggy operator libraries in the NPU toolchain, so we must rely on `session_timeout` to release the board resources held by the hung jobs. Currently the RPC server handles a session timeout poorly: it just kills the server-loop subprocess, after which the destructor of class `RPCEndpoint` sends the code `kShutdown` to the RPC client. Since the client expects `kReturn` or `kException`, users see a confusing error message like the one reported in #15151 that gives no hint of what actually happened.

When tuning to search for a good schedule for operators, we only want to ignore the RPC session timeout error, which indicates that the generated schedule is illegal; other errors reported by the RPC server may help us find potential bugs in the toolchain built on top of TVM. The RPC session timeout error should therefore be split into a standalone TVM error class.

This PR implements these requirements by sending the RPC session timeout error message to the RPC client as an RPC server exception before killing the server-loop subprocess.
1 parent 977b4b2 commit 683dfb0
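
With this change, a tuning driver can tell an expired session apart from a genuine remote failure. A minimal sketch of the intended usage, assuming a hypothetical schedule-evaluation helper `evaluate_schedule`; only the exception class comes from this PR:

    import tvm

    def try_candidate(rpc_sess, candidate):
        """Evaluate one tuning candidate, treating session expiry as a benign skip."""
        try:
            return evaluate_schedule(rpc_sess, candidate)  # hypothetical helper
        except tvm.error.RPCSessionTimeoutError:
            # The candidate schedule hung the device until `session_timeout`
            # elapsed; discard it and keep searching.
            return None
        # Any other tvm.error.RPCError still propagates: it may point to a real
        # bug in the toolchain rather than an illegal schedule.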

File tree

5 files changed: +121 −45 lines

python/tvm/error.py

Lines changed: 5 additions & 0 deletions
@@ -61,6 +61,11 @@ class RPCError(TVMError):
     """Error thrown by the remote server handling the RPC call."""


+@register_error
+class RPCSessionTimeoutError(RPCError, TimeoutError):
+    """Error thrown by the remote server when the RPC session has expired."""
+
+
 @register_error
 class OpError(TVMError):
     """Base class of all operator errors in frontends."""

python/tvm/rpc/server.py

Lines changed: 45 additions & 43 deletions
@@ -25,6 +25,7 @@
 - {server|client}:device-type[:random-key] [-timeout=timeout]
 """
 # pylint: disable=invalid-name
+import os
 import ctypes
 import socket
 import select
@@ -118,16 +119,6 @@ def download_linked_module(file_name):
     return temp


-def _serve_loop(sock, addr, load_library, work_path=None):
-    """Server loop"""
-    sockfd = sock.fileno()
-    temp = _server_env(load_library, work_path)
-    _ffi_api.ServerLoop(sockfd)
-    if not work_path:
-        temp.remove()
-    logger.info("Finish serving %s", addr)
-
-
 def _parse_server_opt(opts):
     # parse client options
     ret = {}
@@ -137,6 +128,47 @@ def _parse_server_opt(opts):
     return ret


+def _serving(sock, addr, opts, load_library):
+    logger.info(f"connected from {addr}")
+    work_path = utils.tempdir()
+    old_cwd = os.getcwd()
+    os.chdir(work_path.path)  # Avoid file name conflicts between sessions.
+    logger.info(f"start serving at {work_path.path}")
+
+    def _serve_loop():
+        _server_env(load_library, work_path)
+        _ffi_api.ServerLoop(sock.fileno())
+
+    server_proc = multiprocessing.Process(target=_serve_loop)
+    server_proc.start()
+    server_proc.join(opts.get("timeout", None))  # Wait until finish or timeout.
+
+    if server_proc.is_alive():
+        logger.info("timeout in RPC session, kill..")
+        _ffi_api.ReturnException(
+            sock.fileno(),
+            f'RPCSessionTimeoutError: Your {opts["timeout"]}s session has expired, '
+            f'try to increase the "session_timeout" value.',
+        )
+
+        try:
+            import psutil  # pylint: disable=import-outside-toplevel
+
+            # Terminate the worker's children first.
+            for child in psutil.Process(server_proc.pid).children(recursive=True):
+                child.terminate()
+        except ImportError:
+            # Don't depend on `psutil` hard: it isn't a pure Python package
+            # and may be difficult to install on some platforms.
+            pass
+        server_proc.terminate()
+
+    logger.info(f"finish serving {addr}")
+    os.chdir(old_cwd)
+    work_path.remove()
+    sock.close()
+
+
 def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
     """Listening loop of the server."""
@@ -237,30 +269,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
             raise exc

         # step 3: serving
-        work_path = utils.tempdir()
-        logger.info("connection from %s", addr)
-        server_proc = multiprocessing.Process(
-            target=_serve_loop, args=(conn, addr, load_library, work_path)
-        )
-
-        server_proc.start()
-        # close from our side.
-        conn.close()
-        # wait until server process finish or timeout
-        server_proc.join(opts.get("timeout", None))
-
-        if server_proc.is_alive():
-            logger.info("Timeout in RPC session, kill..")
-            # pylint: disable=import-outside-toplevel
-            import psutil
-
-            parent = psutil.Process(server_proc.pid)
-            # terminate worker children
-            for child in parent.children(recursive=True):
-                child.terminate()
-            # terminate the worker
-            server_proc.terminate()
-        work_path.remove()
+        _serving(conn, addr, opts, load_library)
@@ -285,15 +294,8 @@ def _connect_proxy_loop(addr, key, load_library):
                 raise RuntimeError(f"{str(addr)} is not RPC Proxy")
             keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
             remote_key = py_str(base.recvall(sock, keylen))
-            opts = _parse_server_opt(remote_key.split()[1:])
-            logger.info("connected to %s", str(addr))
-            process = multiprocessing.Process(target=_serve_loop, args=(sock, addr, load_library))
-            process.start()
-            sock.close()
-            process.join(opts.get("timeout", None))
-            if process.is_alive():
-                logger.info("Timeout in RPC session, kill..")
-                process.terminate()
+
+            _serving(sock, addr, _parse_server_opt(remote_key.split()[1:]), load_library)
             retry_count = 0
         except (socket.error, IOError) as err:
             retry_count += 1

src/runtime/rpc/rpc_endpoint.cc

Lines changed: 6 additions & 2 deletions
@@ -40,6 +40,7 @@

 #include "../../support/arena.h"
 #include "../../support/ring_buffer.h"
+#include "../../support/utils.h"
 #include "../object_internal.h"
 #include "rpc_local_session.h"
@@ -372,8 +373,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     if (code == RPCCode::kException) {
       // switch to the state before sending exception.
       this->SwitchToState(kRecvPacketNumBytes);
-      std::string msg = args[0];
-      LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg;
+      String msg = args[0];
+      if (!support::StartsWith(msg, "RPCSessionTimeoutError: ")) {
+        msg = "RPCError: Error caught from RPC call:\n" + msg;
+      }
+      LOG(FATAL) << msg;
     }

     ICHECK(setreturn != nullptr) << "fsetreturn not available";
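
The C++ handler deliberately leaves the `RPCSessionTimeoutError: ` prefix intact because TVM translates errors by name prefix: a class registered via `@register_error` is keyed by its name, and an error whose message arrives from the C++ side as "ClassName: details" is re-raised on the Python side as that class. A small sketch of the convention (the parsing shown is illustrative, not the actual FFI code):

    import tvm

    msg = 'RPCSessionTimeoutError: Your 1s session has expired, try to increase the "session_timeout" value.'
    cls_name = msg.split(":", 1)[0]  # the prefix names the registered error class
    assert cls_name == "RPCSessionTimeoutError"
    assert issubclass(tvm.error.RPCSessionTimeoutError, tvm.error.RPCError)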

src/runtime/rpc/rpc_socket_impl.cc

Lines changed: 34 additions & 0 deletions
@@ -142,5 +142,39 @@ TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body([](TVMArgs args, TVMRetValue* rv) {
   }
 });

+class SimpleSockHandler : public dmlc::Stream {
+  // Things that interface with the user directly.
+ public:
+  explicit SimpleSockHandler(int sockfd)
+      : sock_(static_cast<support::TCPSocket::SockType>(sockfd)) {}
+  using dmlc::Stream::Read;
+  using dmlc::Stream::ReadArray;
+  using dmlc::Stream::Write;
+  using dmlc::Stream::WriteArray;
+
+  // Unused here, implemented for the microTVM framing layer.
+  void MessageStart(uint64_t packet_nbytes) {}
+  void MessageDone() {}
+
+  // Internal support: override the methods inherited from dmlc::Stream.
+ private:
+  size_t Read(void* data, size_t size) final {
+    ICHECK_EQ(sock_.RecvAll(data, size), size);
+    return size;
+  }
+  void Write(const void* data, size_t size) final { ICHECK_EQ(sock_.SendAll(data, size), size); }
+
+  // State of the current class.
+ private:
+  support::TCPSocket sock_;
+};
+
+TVM_REGISTER_GLOBAL("rpc.ReturnException").set_body_typed([](int sockfd, String msg) {
+  auto handler = SimpleSockHandler(sockfd);
+  RPCReference::ReturnException(msg.c_str(), &handler);
+  return;
+});
+
 }  // namespace runtime
 }  // namespace tvm
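
The `rpc.ReturnException` packed function registered here is what the Python server reaches through `_ffi_api.ReturnException`; it can equally be looked up from the global registry. A minimal sketch, where `sockfd` is assumed to be the `fileno()` of a live RPC session socket:

    import tvm

    return_exception = tvm.get_global_func("rpc.ReturnException")
    # return_exception(sockfd, "RPCSessionTimeoutError: Your 1s session has expired")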

tests/python/unittest/test_runtime_rpc.py

Lines changed: 31 additions & 0 deletions
@@ -606,3 +606,34 @@ def test_rpc_tracker_via_proxy(device_key):
     server1.terminate()
     proxy_server.terminate()
     tracker_server.terminate()
+
+
+@tvm.testing.requires_rpc
+@pytest.mark.parametrize("with_proxy", (True, False))
+def test_rpc_session_timeout_error(with_proxy):
+    port = 9000
+    port_end = 10000
+
+    tracker = Tracker(port=port, port_end=port_end)
+    time.sleep(0.5)
+    tracker_addr = (tracker.host, tracker.port)
+
+    if with_proxy:
+        proxy = Proxy(host="0.0.0.0", port=port, port_end=port_end, tracker_addr=tracker_addr)
+        time.sleep(0.5)
+        server = rpc.Server(host=proxy.host, port=proxy.port, is_proxy=True, key="x1")
+    else:
+        server = rpc.Server(port=port, port_end=port_end, tracker_addr=tracker_addr, key="x1")
+    time.sleep(0.5)
+
+    rpc_sess = rpc.connect_tracker(*tracker_addr).request(key="x1", session_timeout=1)
+
+    with pytest.raises(tvm.error.RPCSessionTimeoutError):
+        f1 = rpc_sess.get_function("rpc.test.addone")
+        time.sleep(2)
+        f1(10)
+
+    server.terminate()
+    if with_proxy:
+        proxy.terminate()
+    tracker.terminate()
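
Note the timing in the test: the session is requested with `session_timeout=1`, while the `time.sleep(2)` guarantees the server-side join expires before `f1(10)` runs, so the call must observe the `RPCSessionTimeoutError` the server returned.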
