@@ -12,43 +12,59 @@ class IPConnectionSettings:
1212 port : int = 25565
1313
1414
15- class IPConnection :
16- def __init__ (self ):
17- self ._reader , self ._writer = (None , None )
15+ @dataclass
16+ class StreamConnection :
17+ reader : asyncio .StreamReader
18+ writer : asyncio .StreamWriter
19+
20+ def __post_init__ (self ):
1821 self ._lock = asyncio .Lock ()
1922
20- async def connect (self , settings : IPConnectionSettings ):
21- self ._reader , self ._writer = await asyncio .open_connection (
22- settings .ip , settings .port
23- )
23+ async def __aenter__ (self ):
24+ await self ._lock .acquire ()
25+ return self
2426
25- def ensure_connected (self ):
26- if self ._reader is None or self ._writer is None :
27+ async def __aexit__ (self , exc_type , exc_val , exc_tb ):
28+ self ._lock .release ()
29+
30+ async def send_message (self , message ) -> None :
31+ self .writer .write (message .encode ("utf-8" ))
32+ await self .writer .drain ()
33+
34+ async def receive_response (self ) -> str :
35+ data = await self .reader .readline ()
36+ return data .decode ("utf-8" )
37+
38+ async def close (self ):
39+ self .writer .close ()
40+ await self .writer .wait_closed ()
41+
42+
43+ class IPConnection :
44+ def __init__ (self ):
45+ self .__connection = None
46+
47+ @property
48+ def _connection (self ) -> StreamConnection :
49+ if self .__connection is None :
2750 raise DisconnectedError ("Need to call connect() before using IPConnection." )
2851
52+ return self .__connection
53+
54+ async def connect (self , settings : IPConnectionSettings ):
55+ reader , writer = await asyncio .open_connection (settings .ip , settings .port )
56+ self .__connection = StreamConnection (reader , writer )
57+
2958 async def send_command (self , message ) -> None :
30- async with self ._lock :
31- self .ensure_connected ()
32- await self ._send_message (message )
59+ async with self ._connection as connection :
60+ await connection .send_message (message )
3361
3462 async def send_query (self , message ) -> str :
35- async with self ._lock :
36- self .ensure_connected ()
37- await self ._send_message (message )
38- return await self ._receive_response ()
63+ async with self ._connection as connection :
64+ await connection .send_message (message )
65+ return await connection .receive_response ()
3966
40- # TODO: Figure out type hinting for connections. TypeGuard fails to work as expected
4167 async def close (self ):
42- async with self ._lock :
43- self .ensure_connected ()
44- self ._writer .close ()
45- await self ._writer .wait_closed ()
46- self ._reader , self ._writer = (None , None )
47-
48- async def _send_message (self , message ) -> None :
49- self ._writer .write (message .encode ("utf-8" ))
50- await self ._writer .drain ()
51-
52- async def _receive_response (self ) -> str :
53- data = await self ._reader .readline ()
54- return data .decode ("utf-8" )
68+ async with self ._connection as connection :
69+ await connection .close ()
70+ self .__connection = None
0 commit comments