|
1 | 1 | import asyncio |
2 | 2 | import functools |
3 | | -import logging |
4 | 3 | from collections import defaultdict |
5 | 4 | from typing import List |
6 | 5 |
|
| 6 | +from haproxyspoa.logging import logger, FlowIdLoggerAdapter |
7 | 7 | from haproxyspoa.payloads.ack import AckPayload |
8 | 8 | from haproxyspoa.payloads.agent_disconnect import DisconnectStatusCode, AgentDisconnectPayload |
9 | 9 | from haproxyspoa.payloads.agent_hello import AgentHelloPayload, AgentCapabilities |
|
12 | 12 | from haproxyspoa.payloads.notify import NotifyPayload |
13 | 13 | from haproxyspoa.spoa_frame import Frame, AgentHelloFrame, FrameType |
14 | 14 |
|
| 15 | +import secrets |
15 | 16 |
|
16 | | -class SpoaServer: |
17 | | - |
18 | | - def __init__(self): |
19 | | - self.handlers = defaultdict(list) |
20 | | - |
21 | | - def handler(self, message_key: str): |
22 | | - def _handler(fn): |
23 | | - @functools.wraps(fn) |
24 | | - def wrapper(*args, **kwargs): |
25 | | - return fn(*args, **kwargs) |
26 | | - self.handlers[message_key].append(wrapper) |
27 | | - return wrapper |
28 | | - return _handler |
29 | | - |
30 | | - async def handle_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): |
31 | | - haproxy_hello_frame = await Frame.read_frame(reader) |
32 | | - if not haproxy_hello_frame.headers.is_haproxy_hello(): |
33 | | - await self.send_agent_disconnect(writer) |
34 | | - return |
35 | | - await self.handle_hello_handshake(haproxy_hello_frame, writer) |
36 | | - |
37 | | - if HaproxyHelloPayload(haproxy_hello_frame.payload).healthcheck(): |
38 | | - logging.info("It is just a health check, immediately disconnecting") |
39 | | - return |
40 | 17 |
|
41 | | - logging.info("Completed new handshake with Haproxy") |
| 18 | +class SpoaConnection: |
| 19 | + |
| 20 | + def __init__(self, writer: asyncio.StreamWriter, handlers): |
| 21 | + self.logger = FlowIdLoggerAdapter(logger, {"flow_id": secrets.token_hex(4)}) |
| 22 | + self.handlers = handlers |
| 23 | + self.writer = writer |
42 | 24 |
|
43 | | - while True: |
44 | | - frame = await Frame.read_frame(reader) |
45 | | - |
46 | | - if frame.headers.is_haproxy_disconnect(): |
47 | | - await self.handle_haproxy_disconnect(frame) |
48 | | - await self.send_agent_disconnect(writer) |
49 | | - return |
50 | | - elif frame.headers.is_haproxy_notify(): |
51 | | - await self.handle_haproxy_notify(frame, writer) |
52 | | - |
53 | | - async def handle_haproxy_notify(self, frame: Frame, writer: asyncio.StreamWriter): |
| 25 | + async def handle_haproxy_notify(self, frame: Frame): |
| 26 | + self.logger.debug("Incoming `notify` frame from HAProxy") |
54 | 27 | notify_payload = NotifyPayload(frame.payload) |
55 | 28 |
|
56 | 29 | response_futures = [] |
57 | 30 | for msg_key, msg_val in notify_payload.messages.items(): |
| 31 | + self.logger.info(f"Received request on key '{msg_key}'") |
58 | 32 | for handler in self.handlers[msg_key]: |
59 | 33 | response_futures.append(handler(**notify_payload.messages[msg_key])) |
60 | 34 |
|
| 35 | + self.logger.info(f"Found {len(response_futures)} matching handlers, awaiting response...") |
61 | 36 | ack_payloads: List[AckPayload] = await asyncio.gather(*response_futures) |
62 | 37 | ack = AckPayload.create_from_all(*ack_payloads) |
| 38 | + payload = ack.to_bytes() |
| 39 | + |
| 40 | + self.logger.info(f"Responding with combined payload of {len(payload.getbuffer())} bytes") |
63 | 41 |
|
64 | 42 | ack_frame = Frame( |
65 | 43 | frame_type=FrameType.ACK, |
66 | 44 | stream_id=frame.headers.stream_id, |
67 | 45 | frame_id=frame.headers.frame_id, |
68 | 46 | flags=1, |
69 | | - payload=ack.to_bytes() |
| 47 | + payload=payload |
70 | 48 | ) |
71 | | - await ack_frame.write_frame(writer) |
| 49 | + await ack_frame.write_frame(self.writer) |
72 | 50 |
|
73 | | - async def send_agent_disconnect(self, writer: asyncio.StreamWriter): |
| 51 | + async def send_agent_disconnect(self): |
| 52 | + self.logger.info("Agent is now dropping connection") |
74 | 53 | disconnect_frame = Frame( |
75 | 54 | frame_type=FrameType.AGENT_DISCONNECT, |
76 | 55 | flags=1, |
77 | 56 | stream_id=0, |
78 | 57 | frame_id=0, |
79 | 58 | payload=AgentDisconnectPayload().to_buffer() |
80 | 59 | ) |
81 | | - await disconnect_frame.write_frame(writer) |
| 60 | + await disconnect_frame.write_frame(self.writer) |
82 | 61 |
|
83 | 62 | async def handle_haproxy_disconnect(self, frame: Frame): |
84 | 63 | payload = HaproxyDisconnectPayload(frame.payload) |
85 | 64 | if payload.status_code() != DisconnectStatusCode.NORMAL: |
86 | | - logging.info(f"Haproxy is disconnecting us with status code {payload.status_code()} - `{payload.message()}`") |
| 65 | + self.logger.info(f"Haproxy is disconnecting us with status code {payload.status_code()} - `{payload.message()}`") |
87 | 66 |
|
88 | | - async def handle_hello_handshake(self, frame: Frame, writer: asyncio.StreamWriter): |
| 67 | + async def handle_hello_handshake(self, frame: Frame): |
| 68 | + capabilities = AgentCapabilities() |
| 69 | + self.logger.info(f"Received `hello handshake`, responding with agent capabilities of: '{capabilities}'") |
89 | 70 | agent_hello_frame = AgentHelloFrame( |
90 | 71 | payload=AgentHelloPayload( |
91 | | - capabilities=AgentCapabilities() |
| 72 | + capabilities=capabilities, |
92 | 73 | ), |
93 | 74 | stream_id=frame.headers.stream_id, |
94 | 75 | frame_id=frame.headers.frame_id, |
95 | 76 | ) |
96 | | - await agent_hello_frame.write_frame(writer) |
| 77 | + await agent_hello_frame.write_frame(self.writer) |
| 78 | + |
| 79 | + |
| 80 | +class SpoaServer: |
| 81 | + |
| 82 | + def __init__(self): |
| 83 | + self.handlers = defaultdict(list) |
| 84 | + |
| 85 | + def handler(self, message_key: str): |
| 86 | + def _handler(fn): |
| 87 | + @functools.wraps(fn) |
| 88 | + def wrapper(*args, **kwargs): |
| 89 | + return fn(*args, **kwargs) |
| 90 | + self.handlers[message_key].append(wrapper) |
| 91 | + return wrapper |
| 92 | + return _handler |
| 93 | + |
| 94 | + async def handle_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): |
| 95 | + conn = SpoaConnection(writer, self.handlers) |
| 96 | + |
| 97 | + haproxy_hello_frame = await Frame.read_frame(reader) |
| 98 | + |
| 99 | + if not haproxy_hello_frame.headers.is_haproxy_hello(): |
| 100 | + conn.logger.error(f""" |
| 101 | + Expected a `hello` frame from HAProxy, |
| 102 | + but received unexpected frame of type {haproxy_hello_frame.headers.frame_type} |
| 103 | + """.strip()) |
| 104 | + await conn.send_agent_disconnect() |
| 105 | + return |
| 106 | + await conn.handle_hello_handshake(haproxy_hello_frame) |
| 107 | + |
| 108 | + if HaproxyHelloPayload(haproxy_hello_frame.payload).healthcheck(): |
| 109 | + conn.logger.info("Health check, immediately disconnecting") |
| 110 | + return |
| 111 | + |
| 112 | + while True: |
| 113 | + frame = await Frame.read_frame(reader) |
| 114 | + |
| 115 | + if frame.headers.is_haproxy_disconnect(): |
| 116 | + await conn.handle_haproxy_disconnect(frame) |
| 117 | + await conn.send_agent_disconnect() |
| 118 | + return |
| 119 | + elif frame.headers.is_haproxy_notify(): |
| 120 | + await conn.handle_haproxy_notify(frame) |
97 | 121 |
|
98 | 122 | async def _run(self, host: str = "0.0.0.0", port: int = 9002): |
99 | 123 | server = await asyncio.start_server(self.handle_connection, host=host, port=port, ) |
| 124 | + logger.info(f"HAProxy SPO Agent listening at {host}:{port}") |
100 | 125 | await server.serve_forever() |
101 | 126 |
|
102 | 127 | def run(self, *args, **kwargs): |
|
0 commit comments