Skip to content

Commit 3cc95a4

Browse files
authored
Remove server sync code and combine with async code. (#1092)
1 parent b6dc63e commit 3cc95a4

File tree

6 files changed

+114
-83
lines changed

6 files changed

+114
-83
lines changed

examples/server_async.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,10 @@ def get_commandline():
285285

286286
# set defaults
287287
comm_defaults = {
288-
"tcp": ("socket", 5020),
289-
"udp": ("socket", 5020),
290-
"serial": ("rtu", "/dev/ptyp0"),
291-
"tls": ("tls", 5020),
288+
"tcp": ["socket", 5020],
289+
"udp": ["socket", 5020],
290+
"serial": ["rtu", "/dev/ptyp0"],
291+
"tls": ["tls", 5020],
292292
}
293293
framers = {
294294
"ascii": ModbusAsciiFramer,

examples/server_sync.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,10 @@ def get_commandline():
286286

287287
# set defaults
288288
comm_defaults = {
289-
"tcp": ("socket", 5020),
290-
"udp": ("socket", 5020),
291-
"serial": ("rtu", "/dev/ptyp0"),
292-
"tls": ("tls", 5020),
289+
"tcp": ["socket", 5020],
290+
"udp": ["socket", 5020],
291+
"serial": ["rtu", "/dev/ptyp0"],
292+
"tls": ["tls", 5020],
293293
}
294294
framers = {
295295
"ascii": ModbusAsciiFramer,

pymodbus/client/tcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ async def _connect(self):
8383
_logger.debug("Connecting.")
8484
try:
8585
transport, protocol = await self.loop.create_connection(
86-
self._create_protocol, self.params.host, self.params.port
86+
self._create_protocol, host=self.params.host, port=self.params.port
8787
)
8888
return transport, protocol
8989
except Exception as exc: # pylint: disable=broad-except

pymodbus/server/async_io.py

Lines changed: 74 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import traceback
77
import ssl
8+
from time import sleep
89

910
import serial
1011
from serial_asyncio import create_serial_connection
@@ -32,7 +33,8 @@
3233
# --------------------------------------------------------------------------- #
3334
# Allow access to server object, to e.g. make a shutdown
3435
# --------------------------------------------------------------------------- #
35-
ServerObject = None # pylint: disable=invalid-name
36+
_server_stopped = None # pylint: disable=invalid-name
37+
_server_stop = None # pylint: disable=invalid-name
3638

3739

3840
def sslctx_provider(
@@ -551,11 +553,15 @@ async def serve_forever(self):
551553
try:
552554
await self.server.serve_forever()
553555
except asyncio.exceptions.CancelledError:
554-
pass
556+
raise
557+
except Exception as exc: # pylint: disable=broad-except
558+
txt = f"Server unexpected exception {exc}"
559+
_logger.error(txt)
555560
else:
556561
raise RuntimeError(
557562
"Can't call serve_forever on an already running server object"
558563
)
564+
_logger.info("Server graceful shutdown.")
559565

560566
async def shutdown(self):
561567
"""Shutdown server."""
@@ -892,6 +898,32 @@ async def serve_forever(self):
892898
# Creation Factories
893899
# --------------------------------------------------------------------------- #
894900

901+
async def _helper_run_server(server, custom_functions):
902+
"""Help starting/stopping server."""
903+
global _server_stopped, _server_stop # pylint: disable=global-statement,invalid-name
904+
905+
for func in custom_functions:
906+
server.decoder.register(func)
907+
_server_stopped = asyncio.Event()
908+
_server_stop = asyncio.Event()
909+
try:
910+
server_task = asyncio.create_task(server.serve_forever())
911+
except Exception as exc: # pylint: disable=broad-except
912+
txt = f"Server caught exception: {exc}"
913+
_logger.error(txt)
914+
await _server_stop.wait()
915+
await server.shutdown()
916+
server_task.cancel()
917+
owntask = asyncio.current_task()
918+
for task in asyncio.all_tasks():
919+
if task != owntask:
920+
task.cancel()
921+
try:
922+
await task
923+
except asyncio.CancelledError:
924+
pass
925+
_server_stopped.set()
926+
895927

896928
async def StartAsyncTcpServer( # pylint: disable=invalid-name,dangerous-default-value
897929
context=None,
@@ -914,17 +946,18 @@ async def StartAsyncTcpServer( # pylint: disable=invalid-name,dangerous-default
914946
:param kwargs: The rest
915947
:return: an initialized but inactive server object coroutine
916948
"""
917-
global ServerObject # pylint: disable=global-statement
918-
919949
framer = kwargs.pop("framer", ModbusSocketFramer)
920-
ServerObject = ModbusTcpServer(context, framer, identity, address, **kwargs)
921-
922-
for func in custom_functions:
923-
ServerObject.decoder.register(func) # pragma: no cover
950+
server = ModbusTcpServer(
951+
context,
952+
framer,
953+
identity,
954+
address,
955+
**kwargs
956+
)
924957

925958
if defer_start:
926-
return ServerObject
927-
await ServerObject.serve_forever()
959+
return server
960+
await _helper_run_server(server, custom_functions)
928961

929962

930963
async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default-value,too-many-arguments
@@ -963,10 +996,8 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default
963996
:param kwargs: The rest
964997
:return: an initialized but inactive server object coroutine
965998
"""
966-
global ServerObject # pylint: disable=global-statement
967-
968999
framer = kwargs.pop("framer", ModbusTlsFramer)
969-
ServerObject = ModbusTlsServer(
1000+
server = ModbusTlsServer(
9701001
context,
9711002
framer,
9721003
identity,
@@ -980,13 +1011,9 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default
9801011
allow_reuse_port=allow_reuse_port,
9811012
**kwargs,
9821013
)
983-
984-
for func in custom_functions:
985-
ServerObject.decoder.register(func) # pragma: no cover
986-
9871014
if defer_start:
988-
return ServerObject
989-
await ServerObject.serve_forever()
1015+
return server
1016+
await _helper_run_server(server, custom_functions)
9901017

9911018

9921019
async def StartAsyncUdpServer( # pylint: disable=invalid-name,dangerous-default-value
@@ -1009,17 +1036,17 @@ async def StartAsyncUdpServer( # pylint: disable=invalid-name,dangerous-default
10091036
up without the ability to shut it off
10101037
:param kwargs:
10111038
"""
1012-
global ServerObject # pylint: disable=global-statement
1013-
10141039
framer = kwargs.pop("framer", ModbusSocketFramer)
1015-
ServerObject = ModbusUdpServer(context, framer, identity, address, **kwargs)
1016-
1017-
for func in custom_functions:
1018-
ServerObject.decoder.register(func) # pragma: no cover
1019-
1040+
server = ModbusUdpServer(
1041+
context,
1042+
framer,
1043+
identity,
1044+
address,
1045+
**kwargs
1046+
)
10201047
if defer_start:
1021-
return ServerObject
1022-
await ServerObject.serve_forever()
1048+
return server
1049+
await _helper_run_server(server, custom_functions)
10231050

10241051

10251052
async def StartAsyncSerialServer( # pylint: disable=invalid-name,dangerous-default-value
@@ -1040,17 +1067,17 @@ async def StartAsyncSerialServer( # pylint: disable=invalid-name,dangerous-defa
10401067
up without the ability to shut it off
10411068
:param kwargs: The rest
10421069
"""
1043-
global ServerObject # pylint: disable=global-statement
1044-
10451070
framer = kwargs.pop("framer", ModbusAsciiFramer)
1046-
ServerObject = ModbusSerialServer(context, framer, identity=identity, **kwargs)
1047-
for func in custom_functions:
1048-
ServerObject.decoder.register(func)
1049-
1071+
server = ModbusSerialServer(
1072+
context,
1073+
framer,
1074+
identity=identity,
1075+
**kwargs
1076+
)
10501077
if defer_start:
1051-
return ServerObject
1052-
await ServerObject.start()
1053-
await ServerObject.serve_forever()
1078+
return server
1079+
await server.start()
1080+
await _helper_run_server(server, custom_functions)
10541081

10551082

10561083
def StartSerialServer(**kwargs): # pylint: disable=invalid-name
@@ -1075,13 +1102,18 @@ def StartUdpServer(**kwargs): # pylint: disable=invalid-name
10751102

10761103
async def ServerAsyncStop(): # pylint: disable=invalid-name
10771104
"""Terminate server."""
1078-
global ServerObject # pylint: disable=global-statement,invalid-name
1105+
global _server_stopped, _server_stop # pylint: disable=invalid-name,global-variable-not-assigned
10791106

1080-
if ServerObject:
1081-
await ServerObject.shutdown()
1082-
ServerObject = None
1107+
_server_stop.set()
1108+
try:
1109+
await _server_stopped.wait()
1110+
except asyncio.exceptions.CancelledError:
1111+
pass
10831112

10841113

10851114
def ServerStop(): # pylint: disable=invalid-name
10861115
"""Terminate server."""
1087-
asyncio.run(ServerAsyncStop())
1116+
global _server_stopped, _server_stop # pylint: disable=invalid-name,global-variable-not-assigned
1117+
1118+
_server_stop.set()
1119+
sleep(10)

test/test_examples.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from time import sleep
77
import logging
88

9-
from unittest.mock import patch, MagicMock
109
import pytest
1110
import pytest_asyncio
1211

@@ -30,11 +29,11 @@
3029
_logger.setLevel("DEBUG")
3130

3231
TEST_COMMS_FRAMER = [
33-
("tcp", ModbusSocketFramer, 5020),
34-
("tcp", ModbusRtuFramer, 5021),
35-
("tls", ModbusTlsFramer, 5030),
36-
("udp", ModbusSocketFramer, 5040),
37-
("udp", ModbusRtuFramer, 5041),
32+
("tcp", ModbusSocketFramer, 5021),
33+
("tcp", ModbusRtuFramer, 5022),
34+
("tls", ModbusTlsFramer, 5023),
35+
("udp", ModbusSocketFramer, 5024),
36+
("udp", ModbusRtuFramer, 5025),
3837
("serial", ModbusRtuFramer, "dummy"),
3938
("serial", ModbusAsciiFramer, "dummy"),
4039
("serial", ModbusBinaryFramer, "dummy"),
@@ -52,34 +51,25 @@ class Commandline:
5251
slaves = None
5352

5453

55-
@pytest_asyncio.fixture(name="mock_libs")
56-
def _helper_libs():
57-
"""Patch ssl and pyserial-async libs."""
58-
with patch('pymodbus.server.async_io.create_serial_connection') as mock_serial:
59-
mock_serial.return_value = (MagicMock(), MagicMock())
60-
yield True
61-
62-
6354
@pytest_asyncio.fixture(name="mock_run_server")
64-
async def _helper_server( # pylint: disable=unused-argument
65-
mock_libs,
55+
async def _helper_server(
6656
test_comm,
6757
test_framer,
58+
test_port_offset,
6859
test_port,
6960
):
7061
"""Run server."""
62+
if test_comm in ("serial"):
63+
yield
64+
return
7165
args = Commandline
7266
args.comm = test_comm
7367
args.framer = test_framer
74-
args.port = test_port
68+
args.port = test_port + test_port_offset
7569
asyncio.create_task(run_async_server(args))
7670
await asyncio.sleep(0.1)
77-
yield True
71+
yield
7872
await ServerAsyncStop()
79-
tasks = asyncio.all_tasks()
80-
owntask = asyncio.current_task()
81-
for i in [i for i in tasks if not (i.done() or i.cancelled() or i == owntask)]:
82-
i.cancel()
8373

8474

8575
async def run_client(
@@ -98,34 +88,43 @@ async def run_client(
9888
await asyncio.sleep(0.1)
9989

10090

91+
@pytest.mark.parametrize("test_port_offset", [10])
10192
@pytest.mark.parametrize("test_comm, test_framer, test_port", TEST_COMMS_FRAMER)
10293
async def test_exp_async_simple( # pylint: disable=unused-argument
10394
test_comm,
10495
test_framer,
96+
test_port_offset,
10597
test_port,
10698
mock_run_server,
10799
):
108100
"""Run async client and server."""
109101

110102

103+
@pytest.mark.parametrize("test_port_offset", [20])
111104
@pytest.mark.parametrize("test_comm, test_framer, test_port", TEST_COMMS_FRAMER)
112-
def test_exp_sync_simple( # pylint: disable=unused-argument
113-
mock_libs,
105+
def test_exp_sync_simple(
114106
test_comm,
115107
test_framer,
108+
test_port_offset,
116109
test_port,
117110
):
118111
"""Run sync client and server."""
112+
if test_comm == "serial":
113+
# missing mock of port
114+
return
119115
args = Commandline
120116
args.comm = test_comm
121-
args.port = test_port
117+
args.port = test_port + test_port_offset
118+
args.framer = test_framer
122119
thread = Thread(target=run_sync_server, args=(args,))
123120
thread.daemon = True
124121
thread.start()
125-
sleep(0.1)
122+
sleep(1)
126123
ServerStop()
124+
_logger.error("jan igen")
127125

128126

127+
@pytest.mark.parametrize("test_port_offset", [30])
129128
@pytest.mark.parametrize("test_comm, test_framer, test_port", TEST_COMMS_FRAMER)
130129
@pytest.mark.parametrize(
131130
"test_type",
@@ -139,6 +138,7 @@ def test_exp_sync_simple( # pylint: disable=unused-argument
139138
async def test_exp_async_framer( # pylint: disable=unused-argument
140139
test_comm,
141140
test_framer,
141+
test_port_offset,
142142
test_port,
143143
mock_run_server,
144144
test_type
@@ -147,11 +147,10 @@ async def test_exp_async_framer( # pylint: disable=unused-argument
147147
if test_type == run_async_ext_calls and test_framer == ModbusRtuFramer: # pylint: disable=comparison-with-callable
148148
return
149149
if test_comm == "serial":
150-
# mocking serial needs to pass data between send/receive
151150
return
152151

153152
args = Commandline
154153
args.framer = test_framer
155154
args.comm = test_comm
156-
args.port = test_port
155+
args.port = test_port + test_port_offset
157156
await run_client(test_comm, test_type, args=args)

0 commit comments

Comments
 (0)