1818from typing import Callable , Dict , List , Optional , Union
1919
2020import pika
21+ import uvicorn
22+ from fastapi import FastAPI , Request
2123from filigran_sseclient import SSEClient
2224from pika .exceptions import NackError , UnroutableError
2325from pydantic import TypeAdapter
3032TRUTHY : List [str ] = ["yes" , "true" , "True" ]
3133FALSY : List [str ] = ["no" , "false" , "False" ]
3234
35+ app = FastAPI ()
36+
3337
3438def killProgramHook (etype , value , tb ):
3539 os .kill (os .getpid (), signal .SIGTERM )
@@ -141,6 +145,35 @@ def ssl_cert_chain(ssl_context, cert_data, key_data, passphrase):
141145 os .unlink (key_file_path )
142146
143147
148+ def create_callback_ssl_context (config ) -> ssl .SSLContext :
149+ listen_protocol_api_ssl_key = get_config_variable (
150+ "LISTEN_PROTOCOL_API_SSL_KEY" ,
151+ ["connector" , "listen_protocol_api_ssl_key" ],
152+ config ,
153+ default = "" ,
154+ )
155+ listen_protocol_api_ssl_cert = get_config_variable (
156+ "LISTEN_PROTOCOL_API_SSL_CERT" ,
157+ ["connector" , "listen_protocol_api_ssl_cert" ],
158+ config ,
159+ default = "" ,
160+ )
161+ listen_protocol_api_ssl_passphrase = get_config_variable (
162+ "LISTEN_PROTOCOL_API_SSL_PASSPHRASE" ,
163+ ["connector" , "listen_protocol_api_ssl_passphrase" ],
164+ config ,
165+ default = "" ,
166+ )
167+ ssl_context = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
168+ ssl_cert_chain (
169+ ssl_context ,
170+ listen_protocol_api_ssl_cert ,
171+ listen_protocol_api_ssl_key ,
172+ listen_protocol_api_ssl_passphrase ,
173+ )
174+ return ssl_context
175+
176+
144177def create_mq_ssl_context (config ) -> ssl .SSLContext :
145178 use_ssl_ca = get_config_variable ("MQ_USE_SSL_CA" , ["mq" , "use_ssl_ca" ], config )
146179 use_ssl_cert = get_config_variable (
@@ -183,9 +216,13 @@ class ListenQueue(threading.Thread):
183216 def __init__ (
184217 self ,
185218 helper ,
219+ opencti_token ,
186220 config : Dict ,
187221 connector_config : Dict ,
188222 applicant_id ,
223+ listen_protocol ,
224+ listen_protocol_api_path ,
225+ listen_protocol_api_port ,
189226 callback ,
190227 ) -> None :
191228 threading .Thread .__init__ (self )
@@ -196,6 +233,10 @@ def __init__(
196233 self .helper = helper
197234 self .callback = callback
198235 self .config = config
236+ self .opencti_token = opencti_token
237+ self .listen_protocol = listen_protocol
238+ self .listen_protocol_api_path = listen_protocol_api_path
239+ self .listen_protocol_api_port = listen_protocol_api_port
199240 self .connector_applicant_id = applicant_id
200241 self .host = connector_config ["connection" ]["host" ]
201242 self .vhost = connector_config ["connection" ]["vhost" ]
@@ -375,52 +416,96 @@ def _data_handler(self, json_data) -> None:
375416 "Failing reporting the processing"
376417 )
377418
419+ async def _process_callback (self , request : Request ):
420+ # 01. Check the authentication
421+ try :
422+ authorization : str = request .headers .get ("Authorization" )
423+ scheme , token = authorization .split ()
424+ if scheme .lower () != "bearer" or token != self .opencti_token :
425+ return {"error" : "Invalid credentials" }
426+ except Exception as e :
427+ return {"error" : "Invalid credentials" }
428+ # 02. Parse the data and execute
429+ try :
430+ data = await request .json () # Get the JSON payload
431+ except Exception as e :
432+ return {"error" : "Invalid JSON payload" , "details" : str (e )}
433+ try :
434+ self ._data_handler (data )
435+ except Exception as e :
436+ return {"error" : "Error processing message" , "details" : str (e )}
437+ # all good
438+ return {"message" : "Message successfully processed" }
439+
378440 def run (self ) -> None :
379- self .helper .connector_logger .info ("Starting ListenQueue thread" )
380- while not self .exit_event .is_set ():
381- try :
382- self .helper .connector_logger .info ("ListenQueue connecting to rabbitMq." )
383- # Connect the broker
384- self .pika_credentials = pika .PlainCredentials (self .user , self .password )
385- self .pika_parameters = pika .ConnectionParameters (
386- heartbeat = 10 ,
387- blocked_connection_timeout = 30 ,
388- host = self .host ,
389- port = self .port ,
390- virtual_host = self .vhost ,
391- credentials = self .pika_credentials ,
392- ssl_options = (
393- pika .SSLOptions (create_mq_ssl_context (self .config ), self .host )
394- if self .use_ssl
395- else None
396- ),
397- )
398- self .pika_connection = pika .BlockingConnection (self .pika_parameters )
399- self .channel = self .pika_connection .channel ()
441+ if self .listen_protocol == "AMQP" :
442+ self .helper .connector_logger .info ("Starting ListenQueue thread" )
443+ while not self .exit_event .is_set ():
400444 try :
401- # confirm_delivery is only for cluster mode rabbitMQ
402- # when not in cluster mode this line raise an exception
403- self .channel .confirm_delivery ()
445+ self .helper .connector_logger .info (
446+ "ListenQueue connecting to rabbitMq."
447+ )
448+ # Connect the broker
449+ self .pika_credentials = pika .PlainCredentials (
450+ self .user , self .password
451+ )
452+ self .pika_parameters = pika .ConnectionParameters (
453+ heartbeat = 10 ,
454+ blocked_connection_timeout = 30 ,
455+ host = self .host ,
456+ port = self .port ,
457+ virtual_host = self .vhost ,
458+ credentials = self .pika_credentials ,
459+ ssl_options = (
460+ pika .SSLOptions (
461+ create_mq_ssl_context (self .config ), self .host
462+ )
463+ if self .use_ssl
464+ else None
465+ ),
466+ )
467+ self .pika_connection = pika .BlockingConnection (self .pika_parameters )
468+ self .channel = self .pika_connection .channel ()
469+ try :
470+ # confirm_delivery is only for cluster mode rabbitMQ
471+ # when not in cluster mode this line raise an exception
472+ self .channel .confirm_delivery ()
473+ except Exception as err : # pylint: disable=broad-except
474+ self .helper .connector_logger .debug (str (err ))
475+ self .channel .basic_qos (prefetch_count = 1 )
476+ assert self .channel is not None
477+ self .channel .basic_consume (
478+ queue = self .queue_name , on_message_callback = self ._process_message
479+ )
480+ self .channel .start_consuming ()
404481 except Exception as err : # pylint: disable=broad-except
405- self .helper .connector_logger .debug (str (err ))
406- self .channel .basic_qos (prefetch_count = 1 )
407- assert self .channel is not None
408- self .channel .basic_consume (
409- queue = self .queue_name , on_message_callback = self ._process_message
410- )
411- self .channel .start_consuming ()
412- except Exception as err : # pylint: disable=broad-except
413- try :
414- self .pika_connection .close ()
415- except Exception as errInException :
416- self .helper .connector_logger .debug (
417- type (errInException ).__name__ , {"reason" : str (errInException )}
482+ try :
483+ self .pika_connection .close ()
484+ except Exception as errInException :
485+ self .helper .connector_logger .debug (
486+ type (errInException ).__name__ ,
487+ {"reason" : str (errInException )},
488+ )
489+ self .helper .connector_logger .error (
490+ type (err ).__name__ , {"reason" : str (err )}
418491 )
419- self .helper .connector_logger .error (
420- type (err ).__name__ , {"reason" : str (err )}
421- )
422- # Wait some time and then retry ListenQueue again.
423- time .sleep (10 )
492+ # Wait some time and then retry ListenQueue again.
493+ time .sleep (10 )
494+ elif self .listen_protocol == "API" :
495+ app .add_api_route (
496+ self .listen_protocol_api_path , self ._process_callback , methods = ["POST" ]
497+ )
498+ ssl_ctx = create_callback_ssl_context (self .config )
499+ config = uvicorn .Config (
500+ app , host = "0.0.0.0" , port = self .listen_protocol_api_port , reload = False
501+ )
502+ config .load () # Manually calling the .load() to trigger needed actions outside HTTPS
503+ config .ssl = ssl_ctx
504+ server = uvicorn .Server (config )
505+ server .run ()
506+
507+ else :
508+ raise ValueError ("Unsupported listen protocol type" )
424509
425510 def stop (self ):
426511 self .helper .connector_logger .info ("Preparing ListenQueue for clean shutdown" )
@@ -790,6 +875,38 @@ def __init__(self, config: Dict, playbook_compatible=False) -> None:
790875 self .connect_id = get_config_variable (
791876 "CONNECTOR_ID" , ["connector" , "id" ], config
792877 )
878+ self .listen_protocol = get_config_variable (
879+ "LISTEN_PROTOCOL" , ["connector" , "listen_protocol" ], config , default = "AMQP"
880+ ).upper ()
881+ self .listen_protocol_api_port = get_config_variable (
882+ "LISTEN_PROTOCOL_API_PORT" ,
883+ ["connector" , "listen_protocol_api_port" ],
884+ config ,
885+ default = 7070 ,
886+ )
887+ self .listen_protocol_api_path = get_config_variable (
888+ "LISTEN_PROTOCOL_API_PATH" ,
889+ ["connector" , "listen_protocol_api_path" ],
890+ config ,
891+ default = "/api/callback" ,
892+ )
893+ self .listen_protocol_api_ssl = get_config_variable (
894+ "LISTEN_PROTOCOL_API_SSL" ,
895+ ["connector" , "listen_protocol_api_ssl" ],
896+ config ,
897+ default = False ,
898+ )
899+
900+ self .listen_protocol_api_uri = get_config_variable (
901+ "LISTEN_PROTOCOL_API_URI" ,
902+ ["connector" , "listen_protocol_api_uri" ],
903+ config ,
904+ default = (
905+ "https://127.0.0.1:7070"
906+ if self .listen_protocol_api_ssl
907+ else "http://127.0.0.1:7070"
908+ ),
909+ )
793910 self .queue_protocol = get_config_variable (
794911 "QUEUE_PROTOCOL" , ["connector" , "queue_protocol" ], config , default = "amqp"
795912 )
@@ -957,6 +1074,7 @@ def __init__(self, config: Dict, playbook_compatible=False) -> None:
9571074 self .connect_auto ,
9581075 self .connect_only_contextual ,
9591076 playbook_compatible ,
1077+ self .listen_protocol_api_uri + self .listen_protocol_api_path ,
9601078 )
9611079 connector_configuration = self .api .connector .register (self .connector )
9621080 self .connector_logger .info (
@@ -1441,9 +1559,13 @@ def listen(
14411559
14421560 self .listen_queue = ListenQueue (
14431561 self ,
1562+ self .opencti_token ,
14441563 self .config ,
14451564 self .connector_config ,
14461565 self .applicant_id ,
1566+ self .listen_protocol ,
1567+ self .listen_protocol_api_path ,
1568+ self .listen_protocol_api_port ,
14471569 message_callback ,
14481570 )
14491571 self .listen_queue .start ()
@@ -1742,13 +1864,13 @@ def send_stix2_bundle(self, bundle: str, **kwargs) -> list:
17421864 raise ValueError ("Nothing to import" )
17431865
17441866 if bundle_send_to_queue :
1745- if work_id :
1746- self .api .work .add_expectations (work_id , expectations_number )
1747- if draft_id :
1748- self .api .work .add_draft_context (work_id , draft_id )
1867+ if work_id and draft_id :
1868+ self .api .work .add_draft_context (work_id , draft_id )
17491869 if entities_types is None :
17501870 entities_types = []
17511871 if self .queue_protocol == "amqp" :
1872+ if work_id :
1873+ self .api .work .add_expectations (work_id , expectations_number )
17521874 pika_credentials = pika .PlainCredentials (
17531875 self .connector_config ["connection" ]["user" ],
17541876 self .connector_config ["connection" ]["pass" ],
@@ -1791,7 +1913,7 @@ def send_stix2_bundle(self, bundle: str, **kwargs) -> list:
17911913 pika_connection .close ()
17921914 elif self .queue_protocol == "api" :
17931915 self .api .send_bundle_to_api (
1794- connector_id = self .connector_id , bundle = bundle
1916+ connector_id = self .connector_id , bundle = bundle , work_id = work_id
17951917 )
17961918 else :
17971919 raise ValueError (
0 commit comments