Skip to content
This repository was archived by the owner on Dec 5, 2025. It is now read-only.

Commit d9bccc8

Browse files
authored
[client] Added async to connectors (#320)
1 parent 1662377 commit d9bccc8

File tree

1 file changed

+71
-29
lines changed

1 file changed

+71
-29
lines changed

pycti/connector/opencti_connector_helper.py

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import base64
23
import datetime
34
import json
@@ -15,13 +16,13 @@
1516
from typing import Callable, Dict, List, Optional, Union
1617

1718
import pika
19+
from pika.adapters.asyncio_connection import AsyncioConnection
1820
from pika.exceptions import NackError, UnroutableError
19-
from sseclient import SSEClient
20-
2121
from pycti.api.opencti_api_client import OpenCTIApiClient
2222
from pycti.connector import LOGGER
2323
from pycti.connector.opencti_connector import OpenCTIConnector
2424
from pycti.utils.opencti_stix2_splitter import OpenCTIStix2Splitter
25+
from sseclient import SSEClient
2526

2627
TRUTHY: List[str] = ["yes", "true", "True"]
2728
FALSY: List[str] = ["no", "false", "False"]
@@ -37,6 +38,11 @@ def killProgramHook(etype, value, tb):
3738
sys.excepthook = killProgramHook
3839

3940

41+
def run_loop(loop):
42+
asyncio.set_event_loop(loop)
43+
loop.run_forever()
44+
45+
4046
def get_config_variable(
4147
env_var: str,
4248
yaml_path: List,
@@ -99,7 +105,7 @@ def create_ssl_context() -> ssl.SSLContext:
99105
return ssl_context
100106

101107

102-
class ListenQueue(threading.Thread):
108+
class ListenQueue:
103109
"""Main class for the ListenQueue used in OpenCTIConnectorHelper
104110
105111
:param helper: instance of a `OpenCTIConnectorHelper` class
@@ -111,7 +117,6 @@ class ListenQueue(threading.Thread):
111117
"""
112118

113119
def __init__(self, helper, config: Dict, callback) -> None:
114-
threading.Thread.__init__(self)
115120
self.pika_credentials = None
116121
self.pika_parameters = None
117122
self.pika_connection = None
@@ -126,10 +131,14 @@ def __init__(self, helper, config: Dict, callback) -> None:
126131
self.password = config["connection"]["pass"]
127132
self.queue_name = config["listen"]
128133
self.exit_event = threading.Event()
129-
self.thread = None
134+
self.connector_thread = None
135+
self.connector_event_loop = None
136+
self.queue_event_loop = asyncio.new_event_loop()
137+
asyncio.set_event_loop(self.queue_event_loop)
138+
self.run()
130139

131140
# noinspection PyUnusedLocal
132-
def _process_message(self, channel, method, properties, body) -> None:
141+
async def _process_message(self, channel, method, properties, body) -> None:
133142
"""process a message from the rabbit queue
134143
135144
:param channel: channel instance
@@ -144,20 +153,27 @@ def _process_message(self, channel, method, properties, body) -> None:
144153

145154
json_data = json.loads(body)
146155
channel.basic_ack(delivery_tag=method.delivery_tag)
147-
self.thread = threading.Thread(target=self._data_handler, args=[json_data])
148-
self.thread.start()
156+
message_task = self._data_handler(json_data)
149157
five_minutes = 60 * 5
150158
time_wait = 0
151-
while self.thread.is_alive(): # Loop while the thread is processing
152-
if (
153-
self.helper.work_id is not None and time_wait > five_minutes
154-
): # Ping every 5 minutes
155-
self.helper.api.work.ping(self.helper.work_id)
156-
time_wait = 0
157-
else:
158-
time_wait += 1
159-
time.sleep(1)
160-
159+
try:
160+
while not message_task.done(): # Loop while the task/thread is processing
161+
if (
162+
self.helper.work_id is not None and time_wait > five_minutes
163+
): # Ping every 5 minutes
164+
self.helper.api.work.ping(self.helper.work_id)
165+
time_wait = 0
166+
else:
167+
time_wait += 1
168+
await asyncio.sleep(1)
169+
self.helper.api.work.to_processed(
170+
json_data["internal"]["work_id"], message_task.result()
171+
)
172+
except Exception as e: # pylint: disable=broad-except
173+
logging.exception("Error in message processing, reporting error to API")
174+
self.helper.api.work.to_processed(
175+
json_data["internal"]["work_id"], str(e), True
176+
)
161177
LOGGER.info(
162178
"Message (delivery_tag=%s) processed, thread terminated",
163179
method.delivery_tag,
@@ -176,8 +192,15 @@ def _data_handler(self, json_data) -> None:
176192
self.helper.api.work.to_received(
177193
work_id, "Connector ready to process the operation"
178194
)
179-
message = self.callback(json_data["event"])
180-
self.helper.api.work.to_processed(work_id, message)
195+
if asyncio.iscoroutinefunction(self.callback):
196+
message = asyncio.run_coroutine_threadsafe(
197+
self.callback(json_data["event"]), self.connector_event_loop
198+
)
199+
else:
200+
message = asyncio.get_running_loop().run_in_executor(
201+
None, self.callback, json_data["event"]
202+
)
203+
return message
181204
except Exception as e: # pylint: disable=broad-except
182205
LOGGER.exception("Error in message processing, reporting error to API")
183206
try:
@@ -199,24 +222,44 @@ def run(self) -> None:
199222
if self.use_ssl
200223
else None,
201224
)
202-
self.pika_connection = pika.BlockingConnection(self.pika_parameters)
203-
self.channel = self.pika_connection.channel()
204-
assert self.channel is not None
205-
self.channel.basic_consume(
206-
queue=self.queue_name, on_message_callback=self._process_message
225+
if asyncio.iscoroutinefunction(self.callback):
226+
self.connector_event_loop = asyncio.new_event_loop()
227+
self.connector_thread = threading.Thread(
228+
target=lambda: run_loop(self.connector_event_loop)
229+
).start()
230+
self.pika_connection = AsyncioConnection(
231+
self.pika_parameters,
232+
on_open_callback=self.on_connection_open,
233+
custom_ioloop=self.queue_event_loop,
207234
)
208-
self.channel.start_consuming()
235+
self.pika_connection.ioloop.run_forever()
209236
except (KeyboardInterrupt, SystemExit):
210237
LOGGER.info("Connector stop")
211238
sys.exit(0)
212239
except Exception as e: # pylint: disable=broad-except
213240
LOGGER.error("%s", e)
214241
time.sleep(10)
215242

243+
# noinspection PyUnusedLocal
244+
def on_connection_open(self, _unused_connection):
245+
self.pika_connection.channel(on_open_callback=self.on_channel_open)
246+
247+
def on_channel_open(self, channel):
248+
self.channel = channel
249+
assert self.channel is not None
250+
self.channel.basic_consume(
251+
queue=self.queue_name,
252+
on_message_callback=lambda *args: asyncio.create_task(
253+
self._process_message(*args)
254+
),
255+
)
256+
216257
def stop(self):
258+
self.queue_event_loop.stop()
217259
self.exit_event.set()
218-
if self.thread:
219-
self.thread.join()
260+
if self.connector_thread:
261+
self.connector_event_loop.stop()
262+
self.connector_thread.join()
220263

221264

222265
class PingAlive(threading.Thread):
@@ -650,7 +693,6 @@ def listen(self, message_callback: Callable[[Dict], str]) -> None:
650693
"""
651694

652695
self.listen_queue = ListenQueue(self, self.config, message_callback)
653-
self.listen_queue.start()
654696

655697
def listen_stream(
656698
self,

0 commit comments

Comments
 (0)