Skip to content

Commit 5b49f2c

Browse files
authored
Support output to webhook (#4598)
1 parent 4785a74 commit 5b49f2c

File tree

14 files changed

+587
-351
lines changed

14 files changed

+587
-351
lines changed

backend/app/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def main() -> None:
118118
app,
119119
host=settings.host,
120120
port=settings.port,
121-
reload=settings.debug and settings.environment == "dev",
121+
# FIXME: reload mode currently does not work with multiple workers
122+
# reload=settings.debug and settings.environment == "dev",
122123
log_level="debug" if settings.debug else "info",
123124
)
124125

backend/app/schemas/sink.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from enum import StrEnum
5+
from os import getenv
56
from typing import Annotated, Literal
67

78
from pydantic import Field, TypeAdapter
89

910
from app.schemas.base import BaseIDNameModel
1011

12+
MQTT_USERNAME = "MQTT_USERNAME"
13+
MQTT_PASSWORD = "MQTT_PASSWORD" # noqa: S105
14+
1115

1216
class SinkType(StrEnum):
1317
DISCONNECTED = "disconnected"
@@ -67,8 +71,7 @@ class MqttSinkConfig(BaseSinkConfig):
6771
broker_host: str
6872
broker_port: int
6973
topic: str
70-
username: str | None = None
71-
password: str | None = None
74+
auth_required: bool = False
7275

7376
model_config = {
7477
"json_schema_extra": {
@@ -80,10 +83,24 @@ class MqttSinkConfig(BaseSinkConfig):
8083
"broker_port": 1883,
8184
"topic": "predictions",
8285
"output_formats": ["predictions"],
86+
"auth_required": True,
8387
}
8488
}
8589
}
8690

91+
def get_credentials(self) -> tuple[str | None, str | None]:
92+
"""Configure stream URL with authentication if required."""
93+
if not self.auth_required:
94+
return None, None
95+
96+
username = getenv(MQTT_USERNAME)
97+
password = getenv(MQTT_PASSWORD)
98+
99+
if not username or not password:
100+
raise RuntimeError("MQTT credentials not provided.")
101+
102+
return username, password
103+
87104

88105
class RosSinkConfig(BaseSinkConfig):
89106
sink_type: Literal[SinkType.ROS]
@@ -102,9 +119,16 @@ class RosSinkConfig(BaseSinkConfig):
102119
}
103120

104121

122+
HttpMethod = Literal["POST", "PUT", "PATCH"]
123+
HttpHeaders = dict[str, str]
124+
125+
105126
class WebhookSinkConfig(BaseSinkConfig):
106127
sink_type: Literal[SinkType.WEBHOOK]
107128
webhook_url: str
129+
http_method: HttpMethod = "POST"
130+
headers: HttpHeaders | None = None
131+
timeout: int = 10 # seconds
108132

109133
model_config = {
110134
"json_schema_extra": {
@@ -113,6 +137,8 @@ class WebhookSinkConfig(BaseSinkConfig):
113137
"sink_type": "webhook",
114138
"name": "Webhook Endpoint",
115139
"webhook_url": "https://example.com/webhook",
140+
"http_method": "PUT",
141+
"headers": {"Authorization": "Bearer YOUR_TOKEN"},
116142
"output_formats": ["predictions"],
117143
}
118144
}

backend/app/services/dispatch_service.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Callable, Sequence
55

66
from app.schemas import Sink, SinkType
7-
from app.services.dispatchers import Dispatcher, FolderDispatcher, MqttDispatcher
7+
from app.services.dispatchers import Dispatcher, FolderDispatcher, MqttDispatcher, WebhookDispatcher
88

99

1010
class DispatchService:
@@ -13,11 +13,11 @@ class DispatchService:
1313
SinkType.FOLDER: lambda config: FolderDispatcher(output_config=config), # type: ignore[union-attr, arg-type]
1414
SinkType.MQTT: lambda config: MqttDispatcher(output_config=config), # type: ignore[union-attr, arg-type]
1515
SinkType.ROS: lambda _: _raise_not_implemented("ROS output is not implemented yet"),
16-
SinkType.WEBHOOK: lambda _: _raise_not_implemented("WEBHOOK output is not implemented yet"),
16+
SinkType.WEBHOOK: lambda config: WebhookDispatcher(output_config=config), # type: ignore[union-attr, arg-type]
1717
}
1818

1919
@classmethod
20-
def get_destination(cls, output_config: Sink) -> Dispatcher | None:
20+
def _get_destination(cls, output_config: Sink) -> Dispatcher | None:
2121
# TODO handle exceptions: if some output cannot be initialized, exclude it and raise a warning
2222
factory = cls._dispatcher_registry.get(output_config.sink_type)
2323
if factory is None:
@@ -33,7 +33,7 @@ def get_destinations(cls, output_configs: Sequence[Sink]) -> list[Dispatcher]:
3333
Args:
3434
output_configs (Sequence[OutputConfig]): A sequence of output configurations.
3535
"""
36-
return [dispatcher for config in output_configs if (dispatcher := cls.get_destination(config)) is not None]
36+
return [dispatcher for config in output_configs if (dispatcher := cls._get_destination(config)) is not None]
3737

3838

3939
def _raise_not_implemented(message: str) -> None:

backend/app/services/dispatchers/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from .filesystem import FolderDispatcher
55
from .mqtt import MqttDispatcher
6+
from .webhook import WebhookDispatcher
67

7-
Dispatcher = FolderDispatcher | MqttDispatcher
8+
Dispatcher = FolderDispatcher | MqttDispatcher | WebhookDispatcher
89

9-
__all__ = ["Dispatcher", "FolderDispatcher", "MqttDispatcher"]
10+
__all__ = ["Dispatcher", "FolderDispatcher", "MqttDispatcher", "WebhookDispatcher"]

backend/app/services/dispatchers/base.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
# Copyright (C) 2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import base64
45
import time
56
from abc import ABCMeta, abstractmethod
7+
from datetime import datetime
8+
from typing import Any
69

10+
import cv2
711
import numpy as np
812
from model_api.models.result import Result
913

10-
from app.schemas import Sink
14+
from app.schemas import OutputFormat, Sink
15+
16+
17+
def numpy_to_base64(image: np.ndarray, fmt: str = ".jpg") -> str:
18+
"""Convert a numpy array image to a base64 string."""
19+
success, img_buf = cv2.imencode(fmt, image)
20+
if success:
21+
return base64.b64encode(img_buf).decode("utf-8")
22+
raise ValueError(f"Failed to encode image in format {fmt}")
1123

1224

1325
class DispatchError(Exception):
@@ -51,6 +63,30 @@ def _dispatch(
5163
predictions (Result): Predictions generated by the model
5264
"""
5365

66+
def _create_payload(
67+
self, original_image: np.ndarray, image_with_visualization: np.ndarray, predictions: Result
68+
) -> dict[str, Any]:
69+
"""Create a JSON payload with the requested output formats."""
70+
result: dict[str, Any] = {}
71+
payload = {"timestamp": datetime.now().isoformat(), "result": result}
72+
73+
if OutputFormat.IMAGE_ORIGINAL in self.output_formats:
74+
result[OutputFormat.IMAGE_ORIGINAL] = {
75+
"data": numpy_to_base64(original_image),
76+
"format": "jpeg",
77+
}
78+
79+
if OutputFormat.IMAGE_WITH_PREDICTIONS in self.output_formats:
80+
result[OutputFormat.IMAGE_WITH_PREDICTIONS] = {
81+
"data": numpy_to_base64(image_with_visualization),
82+
"format": "jpeg",
83+
}
84+
85+
if OutputFormat.PREDICTIONS in self.output_formats:
86+
result[OutputFormat.PREDICTIONS] = str(predictions)
87+
88+
return payload
89+
5490
def dispatch(
5591
self,
5692
original_image: np.ndarray,

backend/app/services/dispatchers/mqtt.py

Lines changed: 19 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
# Copyright (C) 2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import base64
54
import json
65
import logging
76
import threading
87
import time
9-
from datetime import datetime
108
from typing import Any
119

12-
import cv2
1310
import numpy as np
1411
from model_api.models.result import Result
1512

16-
from app.schemas.sink import MqttSinkConfig, OutputFormat
13+
from app.schemas.sink import MqttSinkConfig
1714
from app.services.dispatchers.base import BaseDispatcher
1815

1916
try:
@@ -28,17 +25,6 @@
2825
CONNECT_TIMEOUT = 10
2926

3027

31-
def _encode_image_to_base64(image: np.ndarray, fmt: str = ".jpg") -> str:
32-
success, img_buf = cv2.imencode(fmt, image)
33-
if success:
34-
return base64.b64encode(img_buf.tobytes()).decode("utf-8")
35-
raise ValueError(f"Failed to encode image in format {fmt}")
36-
37-
38-
def _create_mqtt_payload(data_type: str, **kwargs) -> dict[str, Any]:
39-
return {"timestamp": datetime.now().isoformat(), "type": data_type, **kwargs}
40-
41-
4228
class MqttDispatcher(BaseDispatcher):
4329
def __init__(
4430
self,
@@ -65,8 +51,7 @@ def __init__(
6551
self.broker_host = output_config.broker_host
6652
self.broker_port = output_config.broker_port
6753
self.topic = output_config.topic
68-
self.username = output_config.username
69-
self.password = output_config.password
54+
self.username, self.password = output_config.get_credentials()
7055

7156
self._connected = False
7257
self._connection_lock = threading.Lock()
@@ -82,7 +67,7 @@ def _create_default_client(self) -> "mqtt.Client":
8267
client = mqtt.Client(client_id=client_id)
8368
client.on_connect = self._on_connect
8469
client.on_disconnect = self._on_disconnect
85-
if self.username and self.password:
70+
if self.username is not None and self.password is not None:
8671
client.username_pw_set(self.username, self.password)
8772
return client
8873

@@ -119,54 +104,26 @@ def _on_disconnect(self, _client: "mqtt.Client", _userdata: Any, rc: int):
119104
def is_connected(self) -> bool:
120105
return self._connected
121106

122-
def _publish_message(self, topic: str, payload: dict[str, Any]) -> bool:
107+
def __publish_message(self, topic: str, payload: dict[str, Any]) -> None:
123108
if not self._connected:
124109
logger.warning("Client not connected. Reconnecting...")
125110
try:
126111
self._connect()
127-
except Exception:
112+
except ConnectionError:
128113
logger.exception("Reconnect failed")
129-
return False
130114

131115
try:
132116
result = self.client.publish(topic, json.dumps(payload))
133-
if result.rc == mqtt.MQTT_ERR_SUCCESS:
134-
if self._track_messages:
135-
self._published_messages.append({"topic": topic, "payload": payload, "timestamp": datetime.now()})
136-
return True
117+
if result.rc == mqtt.MQTT_ERR_SUCCESS and self._track_messages:
118+
self._published_messages.append({"topic": topic, "payload": payload})
137119
logger.error(f"Publish failed: {mqtt.error_string(result.rc)}")
138-
except Exception:
139-
logger.exception("Publish exception")
140-
return False
141-
142-
def _dispatch_image(self, image: np.ndarray, data_type: str):
143-
try:
144-
image_b64 = _encode_image_to_base64(image)
145-
payload = _create_mqtt_payload(
146-
data_type=data_type,
147-
image=image_b64,
148-
format="jpeg",
149-
)
150-
self._publish_message(self.topic, payload)
151-
except Exception:
152-
logger.exception("Failed to dispatch %s", data_type)
153-
154-
def _dispatch_predictions(self, predictions: Result):
155-
try:
156-
payload = _create_mqtt_payload(data_type=OutputFormat.PREDICTIONS.value, predictions=str(predictions))
157-
self._publish_message(self.topic, payload)
158-
except Exception:
159-
logger.exception("Failed to dispatch predictions")
120+
except ValueError:
121+
logger.exception("Invalid payload for MQTT publish")
160122

161123
def _dispatch(self, original_image: np.ndarray, image_with_visualization: np.ndarray, predictions: Result) -> None:
162-
if OutputFormat.IMAGE_ORIGINAL in self.output_formats:
163-
self._dispatch_image(original_image, OutputFormat.IMAGE_ORIGINAL)
164-
165-
if OutputFormat.IMAGE_WITH_PREDICTIONS in self.output_formats:
166-
self._dispatch_image(image_with_visualization, OutputFormat.IMAGE_WITH_PREDICTIONS)
124+
payload = self._create_payload(original_image, image_with_visualization, predictions)
167125

168-
if OutputFormat.PREDICTIONS in self.output_formats:
169-
self._dispatch_predictions(predictions)
126+
self.__publish_message(self.topic, payload)
170127

171128
def get_published_messages(self) -> list:
172129
return self._published_messages.copy()
@@ -175,11 +132,11 @@ def clear_published_messages(self) -> None:
175132
self._published_messages.clear()
176133

177134
def close(self) -> None:
178-
try:
179-
self.client.loop_stop()
180-
self.client.disconnect()
181-
except Exception:
182-
logger.exception("Error closing dispatcher")
183-
finally:
184-
self._connected = False
185-
self._connection_event.clear()
135+
err = self.client.loop_stop()
136+
if err != mqtt.MQTT_ERR_SUCCESS:
137+
logger.warning(f"Error stopping MQTT loop: {mqtt.error_string(err)}")
138+
err = self.client.disconnect()
139+
if err != mqtt.MQTT_ERR_SUCCESS:
140+
logger.warning(f"Error disconnecting MQTT client: {mqtt.error_string(err)}")
141+
self._connected = False
142+
self._connection_event.clear()
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""This module contains the WebhookDispatcher class for dispatching images and predictions to a webhook endpoint."""
5+
6+
import logging
7+
from typing import Any
8+
9+
import numpy as np
10+
import requests
11+
from model_api.models.result import Result
12+
from requests.adapters import HTTPAdapter
13+
from urllib3.util.retry import Retry
14+
15+
from app.schemas.sink import WebhookSinkConfig
16+
from app.services.dispatchers.base import BaseDispatcher
17+
18+
logger = logging.getLogger(__name__)
19+
20+
MAX_RETRIES = 3
21+
BACKOFF_FACTOR = 0.3
22+
RETRY_ON_STATUS = [500, 502, 503, 504]
23+
24+
25+
class WebhookDispatcher(BaseDispatcher):
26+
def __init__(self, output_config: WebhookSinkConfig) -> None:
27+
"""
28+
Initialize the WebhookDispatcher.
29+
Args:
30+
output_config: Configuration for the webhook-based output destination
31+
"""
32+
super().__init__(output_config=output_config)
33+
self.webhook_url = output_config.webhook_url
34+
self.http_method = output_config.http_method
35+
self.headers = output_config.headers
36+
self.timeout = output_config.timeout
37+
self.session = requests.Session()
38+
retries = Retry(
39+
total=MAX_RETRIES,
40+
backoff_factor=BACKOFF_FACTOR,
41+
status_forcelist=RETRY_ON_STATUS,
42+
allowed_methods=["PATCH", "POST", "PUT"],
43+
)
44+
adapter = HTTPAdapter(max_retries=retries)
45+
self.session.mount("http://", adapter)
46+
self.session.mount("https://", adapter)
47+
48+
def __send_to_webhook(self, payload: dict[str, Any]) -> None:
49+
logger.debug("Sending payload to webhook at %s", self.webhook_url)
50+
response = self.session.request(
51+
self.http_method, self.webhook_url, headers=self.headers, json=payload, timeout=self.timeout
52+
)
53+
response.raise_for_status()
54+
logger.debug("Response from webhook: %s", response.text)
55+
56+
def _dispatch(
57+
self,
58+
original_image: np.ndarray,
59+
image_with_visualization: np.ndarray,
60+
predictions: Result,
61+
) -> None:
62+
payload = self._create_payload(original_image, image_with_visualization, predictions)
63+
64+
self.__send_to_webhook(payload)

0 commit comments

Comments
 (0)