1+ import asyncio
12import base64
23import datetime
34import json
1516from typing import Callable , Dict , List , Optional , Union
1617
1718import pika
19+ from pika .adapters .asyncio_connection import AsyncioConnection
1820from pika .exceptions import NackError , UnroutableError
19- from sseclient import SSEClient
20-
2121from pycti .api .opencti_api_client import OpenCTIApiClient
2222from pycti .connector import LOGGER
2323from pycti .connector .opencti_connector import OpenCTIConnector
2424from pycti .utils .opencti_stix2_splitter import OpenCTIStix2Splitter
25+ from sseclient import SSEClient
2526
2627TRUTHY : List [str ] = ["yes" , "true" , "True" ]
2728FALSY : List [str ] = ["no" , "false" , "False" ]
@@ -37,6 +38,11 @@ def killProgramHook(etype, value, tb):
3738sys .excepthook = killProgramHook
3839
3940
41+ def run_loop (loop ):
42+ asyncio .set_event_loop (loop )
43+ loop .run_forever ()
44+
45+
4046def get_config_variable (
4147 env_var : str ,
4248 yaml_path : List ,
@@ -99,7 +105,7 @@ def create_ssl_context() -> ssl.SSLContext:
99105 return ssl_context
100106
101107
102- class ListenQueue ( threading . Thread ) :
108+ class ListenQueue :
103109 """Main class for the ListenQueue used in OpenCTIConnectorHelper
104110
105111 :param helper: instance of a `OpenCTIConnectorHelper` class
@@ -111,7 +117,6 @@ class ListenQueue(threading.Thread):
111117 """
112118
113119 def __init__ (self , helper , config : Dict , callback ) -> None :
114- threading .Thread .__init__ (self )
115120 self .pika_credentials = None
116121 self .pika_parameters = None
117122 self .pika_connection = None
@@ -126,10 +131,14 @@ def __init__(self, helper, config: Dict, callback) -> None:
126131 self .password = config ["connection" ]["pass" ]
127132 self .queue_name = config ["listen" ]
128133 self .exit_event = threading .Event ()
129- self .thread = None
134+ self .connector_thread = None
135+ self .connector_event_loop = None
136+ self .queue_event_loop = asyncio .new_event_loop ()
137+ asyncio .set_event_loop (self .queue_event_loop )
138+ self .run ()
130139
131140 # noinspection PyUnusedLocal
132- def _process_message (self , channel , method , properties , body ) -> None :
141+ async def _process_message (self , channel , method , properties , body ) -> None :
133142 """process a message from the rabbit queue
134143
135144 :param channel: channel instance
@@ -144,20 +153,27 @@ def _process_message(self, channel, method, properties, body) -> None:
144153
145154 json_data = json .loads (body )
146155 channel .basic_ack (delivery_tag = method .delivery_tag )
147- self .thread = threading .Thread (target = self ._data_handler , args = [json_data ])
148- self .thread .start ()
156+ message_task = self ._data_handler (json_data )
149157 five_minutes = 60 * 5
150158 time_wait = 0
151- while self .thread .is_alive (): # Loop while the thread is processing
152- if (
153- self .helper .work_id is not None and time_wait > five_minutes
154- ): # Ping every 5 minutes
155- self .helper .api .work .ping (self .helper .work_id )
156- time_wait = 0
157- else :
158- time_wait += 1
159- time .sleep (1 )
160-
159+ try :
160+ while not message_task .done (): # Loop while the task/thread is processing
161+ if (
162+ self .helper .work_id is not None and time_wait > five_minutes
163+ ): # Ping every 5 minutes
164+ self .helper .api .work .ping (self .helper .work_id )
165+ time_wait = 0
166+ else :
167+ time_wait += 1
168+ await asyncio .sleep (1 )
169+ self .helper .api .work .to_processed (
170+ json_data ["internal" ]["work_id" ], message_task .result ()
171+ )
172+ except Exception as e : # pylint: disable=broad-except
173+ logging .exception ("Error in message processing, reporting error to API" )
174+ self .helper .api .work .to_processed (
175+ json_data ["internal" ]["work_id" ], str (e ), True
176+ )
161177 LOGGER .info (
162178 "Message (delivery_tag=%s) processed, thread terminated" ,
163179 method .delivery_tag ,
@@ -176,8 +192,15 @@ def _data_handler(self, json_data) -> None:
176192 self .helper .api .work .to_received (
177193 work_id , "Connector ready to process the operation"
178194 )
179- message = self .callback (json_data ["event" ])
180- self .helper .api .work .to_processed (work_id , message )
195+ if asyncio .iscoroutinefunction (self .callback ):
196+ message = asyncio .run_coroutine_threadsafe (
197+ self .callback (json_data ["event" ]), self .connector_event_loop
198+ )
199+ else :
200+ message = asyncio .get_running_loop ().run_in_executor (
201+ None , self .callback , json_data ["event" ]
202+ )
203+ return message
181204 except Exception as e : # pylint: disable=broad-except
182205 LOGGER .exception ("Error in message processing, reporting error to API" )
183206 try :
@@ -199,24 +222,44 @@ def run(self) -> None:
199222 if self .use_ssl
200223 else None ,
201224 )
202- self .pika_connection = pika .BlockingConnection (self .pika_parameters )
203- self .channel = self .pika_connection .channel ()
204- assert self .channel is not None
205- self .channel .basic_consume (
206- queue = self .queue_name , on_message_callback = self ._process_message
225+ if asyncio .iscoroutinefunction (self .callback ):
226+ self .connector_event_loop = asyncio .new_event_loop ()
227+ self .connector_thread = threading .Thread (
228+ target = lambda : run_loop (self .connector_event_loop )
229+ ).start ()
230+ self .pika_connection = AsyncioConnection (
231+ self .pika_parameters ,
232+ on_open_callback = self .on_connection_open ,
233+ custom_ioloop = self .queue_event_loop ,
207234 )
208- self .channel . start_consuming ()
235+ self .pika_connection . ioloop . run_forever ()
209236 except (KeyboardInterrupt , SystemExit ):
210237 LOGGER .info ("Connector stop" )
211238 sys .exit (0 )
212239 except Exception as e : # pylint: disable=broad-except
213240 LOGGER .error ("%s" , e )
214241 time .sleep (10 )
215242
243+ # noinspection PyUnusedLocal
244+ def on_connection_open (self , _unused_connection ):
245+ self .pika_connection .channel (on_open_callback = self .on_channel_open )
246+
247+ def on_channel_open (self , channel ):
248+ self .channel = channel
249+ assert self .channel is not None
250+ self .channel .basic_consume (
251+ queue = self .queue_name ,
252+ on_message_callback = lambda * args : asyncio .create_task (
253+ self ._process_message (* args )
254+ ),
255+ )
256+
216257 def stop (self ):
258+ self .queue_event_loop .stop ()
217259 self .exit_event .set ()
218- if self .thread :
219- self .thread .join ()
260+ if self .connector_thread :
261+ self .connector_event_loop .stop ()
262+ self .connector_thread .join ()
220263
221264
222265class PingAlive (threading .Thread ):
@@ -650,7 +693,6 @@ def listen(self, message_callback: Callable[[Dict], str]) -> None:
650693 """
651694
652695 self .listen_queue = ListenQueue (self , self .config , message_callback )
653- self .listen_queue .start ()
654696
655697 def listen_stream (
656698 self ,
0 commit comments