|
3 | 3 | import json |
4 | 4 | import logging |
5 | 5 | import os |
6 | | -import queue |
7 | 6 | import ssl |
8 | 7 | import threading |
9 | 8 | import time |
10 | 9 | import uuid |
11 | 10 | from typing import Callable, Dict, List, Optional, Union |
12 | 11 |
|
13 | 12 | import pika |
14 | | -import requests |
15 | 13 | from pika.exceptions import NackError, UnroutableError |
16 | 14 | from sseclient import SSEClient |
17 | 15 |
|
18 | 16 | from pycti.api.opencti_api_client import OpenCTIApiClient |
19 | 17 | from pycti.connector.opencti_connector import OpenCTIConnector |
20 | 18 | from pycti.utils.opencti_stix2_splitter import OpenCTIStix2Splitter |
21 | 19 |
|
22 | | -EVENTS_QUEUE = queue.Queue() |
23 | | - |
24 | 20 |
|
25 | 21 | def get_config_variable( |
26 | 22 | env_var: str, |
@@ -222,81 +218,114 @@ def run(self): |
222 | 218 | self.ping() |
223 | 219 |
|
224 | 220 |
|
225 | | -class StreamCatcher(threading.Thread): |
226 | | - def __init__( |
227 | | - self, |
228 | | - opencti_url, |
229 | | - opencti_token, |
230 | | - connector_last_event_id, |
231 | | - last_event_id, |
232 | | - stream_connection_id, |
233 | | - ): |
| 221 | +class ListenStream(threading.Thread): |
| 222 | + def __init__(self, helper, callback, url, token, verify_ssl): |
234 | 223 | threading.Thread.__init__(self) |
235 | | - self.opencti_url = opencti_url |
236 | | - self.opencti_token = opencti_token |
237 | | - self.connector_last_event_id = connector_last_event_id |
238 | | - self.last_event_id = last_event_id |
239 | | - self.stream_connection_id = stream_connection_id |
240 | | - self.session = requests.session() |
241 | | - |
242 | | - def get_range(self, from_id): |
243 | | - payload = { |
244 | | - "from": from_id, |
245 | | - "size": 2000, |
246 | | - "connectionId": self.stream_connection_id, |
247 | | - } |
248 | | - headers = {"Authorization": "Bearer " + self.opencti_token} |
249 | | - r = self.session.post( |
250 | | - self.opencti_url + "/stream/history", json=payload, headers=headers |
251 | | - ) |
252 | | - result = r.json() |
253 | | - if result and "lastEventId" in result: |
254 | | - return result["lastEventId"] |
| 224 | + self.helper = helper |
| 225 | + self.callback = callback |
| 226 | + self.url = url |
| 227 | + self.token = token |
| 228 | + self.verify_ssl = verify_ssl |
255 | 229 |
|
256 | 230 | def run(self): |
257 | | - if self.connector_last_event_id: |
258 | | - from_event_id = self.connector_last_event_id |
259 | | - # If from event ID is "-", start from the beginning |
260 | | - if from_event_id == "-": |
261 | | - from_event_timestamp = 0 |
262 | | - # If from event ID is a "pure" timestamp |
263 | | - elif "-" not in str(from_event_id): |
264 | | - from_event_timestamp = int(from_event_id) |
265 | | - elif "-" in str(from_event_id): |
266 | | - from_event_timestamp = int(str(from_event_id).split("-")[0]) |
| 231 | + current_state = self.helper.get_state() |
| 232 | + if current_state is None: |
| 233 | + current_state = {"connectorLastEventId": "-"} |
| 234 | + self.helper.set_state(current_state) |
| 235 | + |
| 236 | + # If URL and token are provided, likely consuming a remote stream |
| 237 | + if self.url is not None and self.token is not None: |
| 238 | + # If a live stream ID, appending the URL |
| 239 | + live_stream_uri = ( |
| 240 | + ("/" + self.helper.connect_live_stream_id) |
| 241 | + if self.helper.connect_live_stream_id is not None |
| 242 | + else "" |
| 243 | + ) |
| 244 | + # Live stream "from" should be empty if start from the beginning |
| 245 | + if self.helper.connect_live_stream_id is not None: |
| 246 | + live_stream_from = ( |
| 247 | + ("?from=" + current_state["connectorLastEventId"]) |
| 248 | + if current_state["connectorLastEventId"] != "-" |
| 249 | + else "" |
| 250 | + ) |
| 251 | + # Global stream "from" should be 0 if starting from the beginning |
267 | 252 | else: |
268 | | - from_event_timestamp = 0 |
269 | | - last_event_timestamp = int(self.last_event_id.split("-")[0]) |
270 | | - if from_event_timestamp > last_event_timestamp: |
271 | | - from_event_timestamp = last_event_timestamp - 1 |
272 | | - from_event_id = str(from_event_timestamp) + "-0" |
273 | | - while ( |
274 | | - from_event_timestamp <= last_event_timestamp |
275 | | - and from_event_id != self.last_event_id |
276 | | - ): |
277 | | - from_event_id = self.get_range(from_event_id) |
278 | | - from_event_timestamp = int(from_event_id.split("-")[0]) |
279 | | - logging.info("Events catchup requests done.") |
280 | | - |
281 | | - |
282 | | -class StreamProcessor(threading.Thread): |
283 | | - def __init__(self, message_callback, get_state, set_state): |
284 | | - threading.Thread.__init__(self) |
285 | | - self.message_callback = message_callback |
286 | | - self.get_state = get_state |
287 | | - self.set_state = set_state |
| 253 | + live_stream_from = "?from=" + ( |
| 254 | + current_state["connectorLastEventId"] |
| 255 | + if current_state["connectorLastEventId"] != "-" |
| 256 | + else "0" |
| 257 | + ) |
| 258 | + live_stream_url = self.url + "/stream" + live_stream_uri + live_stream_from |
| 259 | + opencti_ssl_verify = ( |
| 260 | + self.verify_ssl if self.verify_ssl is not None else True |
| 261 | + ) |
| 262 | + logging.info( |
| 263 | + "Starting listening stream events (URL: " |
| 264 | + + live_stream_url |
| 265 | + + ", SSL verify: " |
| 266 | + + str(opencti_ssl_verify) |
| 267 | + + ")" |
| 268 | + ) |
| 269 | + messages = SSEClient( |
| 270 | + live_stream_url, |
| 271 | + headers={"Authorization": "Bearer " + self.token}, |
| 272 | + verify=opencti_ssl_verify, |
| 273 | + ) |
| 274 | + else: |
| 275 | + live_stream_uri = ( |
| 276 | + ("/" + self.helper.connect_live_stream_id) |
| 277 | + if self.helper.connect_live_stream_id is not None |
| 278 | + else "" |
| 279 | + ) |
| 280 | + if self.helper.connect_live_stream_id is not None: |
| 281 | + live_stream_from = ( |
| 282 | + ("?from=" + current_state["connectorLastEventId"]) |
| 283 | + if current_state["connectorLastEventId"] != "-" |
| 284 | + else "" |
| 285 | + ) |
| 286 | + # Global stream "from" should be 0 if starting from the beginning |
| 287 | + else: |
| 288 | + live_stream_from = "?from=" + ( |
| 289 | + current_state["connectorLastEventId"] |
| 290 | + if current_state["connectorLastEventId"] != "-" |
| 291 | + else "0" |
| 292 | + ) |
| 293 | + live_stream_url = ( |
| 294 | + self.helper.opencti_url + "/stream" + live_stream_uri + live_stream_from |
| 295 | + ) |
| 296 | + logging.info( |
| 297 | + "Starting listening stream events (URL: " |
| 298 | + + live_stream_url |
| 299 | + + ", SSL verify: " |
| 300 | + + str(self.helper.opencti_ssl_verify) |
| 301 | + + ")" |
| 302 | + ) |
| 303 | + messages = SSEClient( |
| 304 | + live_stream_url, |
| 305 | + headers={"Authorization": "Bearer " + self.helper.opencti_token}, |
| 306 | + verify=self.helper.opencti_ssl_verify, |
| 307 | + ) |
288 | 308 |
|
289 | | - def run(self): |
290 | | - logging.info("All old events processed, consuming is now LIVE!") |
291 | | - while True: |
292 | | - msg = EVENTS_QUEUE.get(block=True, timeout=None) |
293 | | - self.message_callback(msg) |
294 | | - state = self.get_state() |
295 | | - if state is not None: |
296 | | - state["connectorLastEventId"] = msg.id |
297 | | - self.set_state(state) |
| 309 | + for msg in messages: |
| 310 | + if msg.event == "heartbeat": |
| 311 | + logging.info("HEARTBEAT:" + str(msg)) |
| 312 | + continue |
| 313 | + elif msg.event == "connected": |
| 314 | + logging.info("CONNECTED:" + str(msg)) |
| 315 | + elif msg.event == "catch": |
| 316 | + logging.info("Catchup done") |
298 | 317 | else: |
299 | | - self.set_state({"connectorLastEventId": msg.id}) |
| 318 | + event_id = msg.id |
| 319 | + date = datetime.datetime.fromtimestamp( |
| 320 | + round(int(event_id.split("-")[0]) / 1000) |
| 321 | + ) |
| 322 | + logging.info( |
| 323 | + "Processing message (id: " + event_id + ", date: " + str(date) + ")" |
| 324 | + ) |
| 325 | + self.callback(msg) |
| 326 | + state = self.helper.get_state() |
| 327 | + state["connectorLastEventId"] = str(msg.id) |
| 328 | + self.helper.set_state(state) |
300 | 329 |
|
301 | 330 |
|
302 | 331 | class OpenCTIConnectorHelper: |
@@ -324,6 +353,13 @@ def __init__(self, config: dict): |
324 | 353 | self.connect_type = get_config_variable( |
325 | 354 | "CONNECTOR_TYPE", ["connector", "type"], config |
326 | 355 | ) |
| 356 | + self.connect_live_stream_id = get_config_variable( |
| 357 | + "CONNECTOR_LIVE_STREAM_ID", |
| 358 | + ["connector", "live_stream_id"], |
| 359 | + config, |
| 360 | + False, |
| 361 | + None, |
| 362 | + ) |
327 | 363 | self.connect_name = get_config_variable( |
328 | 364 | "CONNECTOR_NAME", ["connector", "name"], config |
329 | 365 | ) |
@@ -427,97 +463,19 @@ def listen(self, message_callback: Callable[[str, Dict], str]) -> None: |
427 | 463 | listen_queue.start() |
428 | 464 |
|
429 | 465 | def listen_stream( |
430 | | - self, message_callback, url=None, token=None, verify=None |
| 466 | + self, |
| 467 | + message_callback, |
| 468 | + url=None, |
| 469 | + token=None, |
| 470 | + verify_ssl=None, |
431 | 471 | ) -> None: |
432 | 472 | """listen for messages and register callback function |
433 | 473 |
|
434 | 474 | :param message_callback: callback function to process messages |
435 | 475 | """ |
436 | | - current_state = self.get_state() |
437 | | - if current_state is None: |
438 | | - current_state = {"connectorLastEventId": "-"} |
439 | 476 |
|
440 | | - # Get the last event ID with the "connected" event msg |
441 | | - if url is not None and token is not None: |
442 | | - opencti_ssl_verify = verify if verify is not None else True |
443 | | - logging.info( |
444 | | - "Starting listening stream events with SSL verify to: " |
445 | | - + str(opencti_ssl_verify) |
446 | | - ) |
447 | | - messages = SSEClient( |
448 | | - url + "/stream", |
449 | | - headers={"Authorization": "Bearer " + token}, |
450 | | - verify=opencti_ssl_verify, |
451 | | - ) |
452 | | - else: |
453 | | - logging.info( |
454 | | - "Starting listening stream events with SSL verify to: " |
455 | | - + str(self.opencti_ssl_verify) |
456 | | - ) |
457 | | - messages = SSEClient( |
458 | | - self.opencti_url + "/stream", |
459 | | - headers={"Authorization": "Bearer " + self.opencti_token}, |
460 | | - verify=self.opencti_ssl_verify, |
461 | | - ) |
462 | | - |
463 | | - # Create processor thread |
464 | | - processor_thread = StreamProcessor( |
465 | | - message_callback, self.get_state, self.set_state |
466 | | - ) |
467 | | - |
468 | | - last_event_id = None |
469 | | - for msg in messages: |
470 | | - try: |
471 | | - data = json.loads(msg.data) |
472 | | - except: |
473 | | - logging.error("Failed to load JSON: " + msg.data) |
474 | | - continue |
475 | | - if msg.event == "heartbeat": |
476 | | - logging.info("HEARTBEAT:" + str(msg)) |
477 | | - continue |
478 | | - elif msg.event == "connected": |
479 | | - last_event_id = data["lastEventId"] |
480 | | - stream_connection_id = data["connectionId"] |
481 | | - # Launch processor if up to date |
482 | | - if current_state["connectorLastEventId"] == last_event_id: |
483 | | - processor_thread.start() |
484 | | - # Launch catcher if not up to date |
485 | | - if last_event_id != current_state["connectorLastEventId"]: |
486 | | - logging.info( |
487 | | - "Some events have not been processed, catching them..." |
488 | | - ) |
489 | | - if url is not None and token is not None: |
490 | | - catcher_thread = StreamCatcher( |
491 | | - url, |
492 | | - token, |
493 | | - current_state["connectorLastEventId"], |
494 | | - last_event_id, |
495 | | - stream_connection_id, |
496 | | - ) |
497 | | - else: |
498 | | - catcher_thread = StreamCatcher( |
499 | | - self.opencti_url, |
500 | | - self.opencti_token, |
501 | | - current_state["connectorLastEventId"], |
502 | | - last_event_id, |
503 | | - stream_connection_id, |
504 | | - ) |
505 | | - catcher_thread.start() |
506 | | - else: |
507 | | - # If receiving the last message, launch processor |
508 | | - if msg.id == last_event_id: |
509 | | - message_callback(msg) |
510 | | - processor_thread.start() |
511 | | - elif "catchup" not in data: |
512 | | - EVENTS_QUEUE.put(msg) |
513 | | - else: |
514 | | - message_callback(msg) |
515 | | - state = self.get_state() |
516 | | - if state is not None: |
517 | | - state["connectorLastEventId"] = msg.id |
518 | | - self.set_state(state) |
519 | | - else: |
520 | | - self.set_state({"connectorLastEventId": msg.id}) |
| 477 | + listen_stream = ListenStream(self, message_callback, url, token, verify_ssl) |
| 478 | + listen_stream.start() |
521 | 479 |
|
522 | 480 | def get_opencti_url(self): |
523 | 481 | return self.opencti_url |
|
0 commit comments