diff --git a/openwisp_controller/config/tasks.py b/openwisp_controller/config/tasks.py index 4c3a44b80..bb5243a8e 100644 --- a/openwisp_controller/config/tasks.py +++ b/openwisp_controller/config/tasks.py @@ -9,6 +9,8 @@ from openwisp_utils.tasks import OpenwispCeleryTask +from .utils import handle_error_notification, handle_recovery_notification + logger = logging.getLogger(__name__) @@ -108,6 +110,7 @@ def trigger_vpn_server_endpoint(endpoint, auth_token, vpn_id): # Cache the configuration here makes downloading the configuration faster. vpn.get_cached_configuration() + task_key = f"vpn_update_task:{vpn_id}" response = requests.post( endpoint, params={"key": auth_token}, @@ -115,12 +118,23 @@ def trigger_vpn_server_endpoint(endpoint, auth_token, vpn_id): ) if response.status_code == 200: logger.info(f"Triggered update webhook of VPN Server UUID: {vpn_id}") + handle_recovery_notification( + task_key, + instance=vpn, + action="update", + ) else: logger.error( "Failed to update VPN Server configuration. " f"Response status code: {response.status_code}, " f"VPN Server UUID: {vpn_id}", ) + handle_error_notification( + task_key, + response, + instance=vpn, + action="update", + ) @shared_task(soft_time_limit=7200) diff --git a/openwisp_controller/config/tasks_zerotier.py b/openwisp_controller/config/tasks_zerotier.py index f365ec6e9..1e14d04ea 100644 --- a/openwisp_controller/config/tasks_zerotier.py +++ b/openwisp_controller/config/tasks_zerotier.py @@ -1,12 +1,8 @@ import logging from http import HTTPStatus -from time import sleep from celery import shared_task -from django.core.cache import cache from django.core.exceptions import ObjectDoesNotExist -from django.utils.translation import gettext as _ -from openwisp_notifications.signals import notify from requests.exceptions import RequestException from swapper import load_model @@ -14,6 +10,7 @@ from openwisp_utils.tasks import OpenwispCeleryTask from .settings import API_TASK_RETRY_OPTIONS +from .utils import handle_error_notification, handle_recovery_notification logger = logging.getLogger(__name__) @@ -27,48 +24,6 @@ class OpenwispApiTask(OpenwispCeleryTask): HTTPStatus.GATEWAY_TIMEOUT, # 504 ] - def _send_api_task_notification(self, type, **kwargs): - vpn = kwargs.get("instance") - action = kwargs.get("action").replace("_", " ") - status_code = kwargs.get("status_code") - # Adding some delay here to prevent overlapping - # of the django success message container - # with the ow-notification container - # https://github.com/openwisp/openwisp-notifications/issues/264 - sleep(2) - message_map = { - "error": { - "verb": _("encountered an unrecoverable error"), - "message": _( - "Unable to perform {action} operation on the " - "{target} VPN server due to an " - "unrecoverable error " - "(status code: {status_code})" - ), - "level": "error", - }, - "recovery": { - "verb": _("has been completed successfully"), - "message": _("The {action} operation on {target} {verb}."), - "level": "info", - }, - } - meta = message_map[type] - notify.send( - type="generic_message", - sender=vpn, - target=vpn, - action=action, - verb=meta["verb"], - message=meta["message"].format( - action=action, - target=str(vpn), - status_code=status_code, - verb=meta["verb"], - ), - level=meta["level"], - ) - def handle_api_call(self, fn, *args, send_notification=True, **kwargs): """ This method handles API calls and their responses @@ -105,10 +60,7 @@ def handle_api_call(self, fn, *args, send_notification=True, **kwargs): response.raise_for_status() logger.info(info_msg) if send_notification: - task_result = cache.get(task_key) - if task_result == "error": - self._send_api_task_notification("recovery", **kwargs) - cache.set(task_key, "success", None) + handle_recovery_notification(task_key, **kwargs) except RequestException as e: if response.status_code in self._RECOVERABLE_API_CODES: retry_logger = logger.warn @@ -122,12 +74,7 @@ def handle_api_call(self, fn, *args, send_notification=True, **kwargs): raise e logger.error(f"{err_msg}, Error: {e}") if send_notification: - task_result = cache.get(task_key) - if task_result in (None, "success"): - cache.set(task_key, "error", None) - self._send_api_task_notification( - "error", status_code=response.status_code, **kwargs - ) + handle_error_notification(task_key, response, **kwargs) return (response, updated_config) if updated_config else response diff --git a/openwisp_controller/config/tests/test_vpn.py b/openwisp_controller/config/tests/test_vpn.py index d02200873..21456d8c0 100644 --- a/openwisp_controller/config/tests/test_vpn.py +++ b/openwisp_controller/config/tests/test_vpn.py @@ -516,11 +516,10 @@ def test_update_vpn_dh(self, dhparam): def test_vpn_server_change_invalidates_device_cache(self): device, vpn, template = self._create_wireguard_vpn_template() - with catch_signal( - vpn_server_modified - ) as mocked_vpn_server_modified, catch_signal( - config_modified - ) as mocked_config_modified: + with ( + catch_signal(vpn_server_modified) as mocked_vpn_server_modified, + catch_signal(config_modified) as mocked_config_modified, + ): vpn.host = "localhost" vpn.save(update_fields=["host"]) mocked_vpn_server_modified.assert_called_once_with( @@ -766,15 +765,18 @@ def test_auto_peer_configuration(self): "organization": device.organization, } ) - with mock.patch.object( - Vpn, - "invalidate_checksum_cache", - return_value=vpn.invalidate_checksum_cache(), - ) as mocked_invalidate_checksum_cache, mock.patch.object( - Vpn, - "get_cached_configuration", - return_value=vpn.get_cached_configuration(), - ) as mocked_cached_configuration: + with ( + mock.patch.object( + Vpn, + "invalidate_checksum_cache", + return_value=vpn.invalidate_checksum_cache(), + ) as mocked_invalidate_checksum_cache, + mock.patch.object( + Vpn, + "get_cached_configuration", + return_value=vpn.get_cached_configuration(), + ) as mocked_cached_configuration, + ): device2.config.templates.add(template) # The Vpn configuration cache is invalidated and re-populated mocked_invalidate_checksum_cache.assert_called_once() @@ -785,15 +787,18 @@ def test_auto_peer_configuration(self): self.assertEqual(len(vpn_config.get("peers", [])), 2) with self.subTest("cache updated when a peer is deleted"): - with mock.patch.object( - Vpn, - "invalidate_checksum_cache", - return_value=vpn.invalidate_checksum_cache(), - ) as mocked_invalidate_checksum_cache, mock.patch.object( - Vpn, - "get_cached_configuration", - return_value=vpn.get_cached_configuration(), - ) as mocked_cached_configuration: + with ( + mock.patch.object( + Vpn, + "invalidate_checksum_cache", + return_value=vpn.invalidate_checksum_cache(), + ) as mocked_invalidate_checksum_cache, + mock.patch.object( + Vpn, + "get_cached_configuration", + return_value=vpn.get_cached_configuration(), + ) as mocked_cached_configuration, + ): device2.delete(check_deactivated=False) mocked_invalidate_checksum_cache.assert_called_once() mocked_cached_configuration.assert_not_called() @@ -830,11 +835,14 @@ def test_update_vpn_server_configuration(self): vpn.auth_token = "super-secret-token" vpn.save() vpn_client.refresh_from_db() - - with mock.patch( - "openwisp_controller.config.tasks.logger.info" - ) as mocked_logger, mock.patch( - "requests.post", return_value=HttpResponse() + with ( + mock.patch( + "openwisp_controller.config.tasks.logger.info" + ) as mocked_logger, + mock.patch("requests.post", return_value=HttpResponse()), + mock.patch( + "openwisp_controller.config.tasks.handle_recovery_notification" + ) as mocked_recovery, ): post_save.send( instance=vpn_client, sender=vpn_client._meta.model, created=False @@ -842,9 +850,16 @@ def test_update_vpn_server_configuration(self): mocked_logger.assert_called_once_with( f"Triggered update webhook of VPN Server UUID: {vpn.pk}" ) - - with mock.patch("logging.Logger.error") as mocked_logger, mock.patch( - "requests.post", return_value=HttpResponseNotFound() + mocked_recovery.assert_called_once() + args, kwargs = mocked_recovery.call_args + self.assertEqual(kwargs["instance"], vpn) + self.assertEqual(kwargs["action"], "update") + with ( + mock.patch("logging.Logger.error") as mocked_logger, + mock.patch("requests.post", return_value=HttpResponseNotFound()), + mock.patch( + "openwisp_controller.config.tasks.handle_error_notification" + ) as mocked_error, ): post_save.send( instance=vpn_client, sender=vpn_client._meta.model, created=False @@ -853,6 +868,10 @@ def test_update_vpn_server_configuration(self): "Failed to update VPN Server configuration. " f"Response status code: 404, VPN Server UUID: {vpn.pk}" ) + mocked_error.assert_called_once() + args, kwargs = mocked_error.call_args + self.assertEqual(kwargs["instance"], vpn) + self.assertEqual(kwargs["action"], "update") def test_vpn_peers_changed(self): with self.subTest("VpnClient created"): @@ -1844,18 +1863,22 @@ def test_zerotier_update_vpn_server_configuration( mock_error.reset_mock() mock_requests.reset_mock() - with self.subTest( - "Test zerotier configuration update " - "with retry mechanism (recoverable errors)" - ), mock.patch("celery.app.task.Task.request") as mock_task_request: + with ( + self.subTest( + "Test zerotier configuration update " + "with retry mechanism (recoverable errors)" + ), + mock.patch("celery.app.task.Task.request") as mock_task_request, + ): max_retries = API_TASK_RETRY_OPTIONS.get("max_retries") mock_task_request.called_directly = False config = vpn.get_config()["zerotier"][0] config.update({"private": True}) - with self.subTest( - "Test update when max retry limit is not reached" - ), self.assertRaises(Retry): + with ( + self.subTest("Test update when max retry limit is not reached"), + self.assertRaises(Retry), + ): mock_requests.get.side_effect = [ # For node status self._get_mock_response(200, response=self._TEST_ZT_NODE_CONFIG) @@ -1906,9 +1929,10 @@ def test_zerotier_update_vpn_server_configuration( # During the last attempt, the task will give up # retrying and raise a 'RequestException', # which will be handled and logged as an error - with self.subTest( - "Test update when max retry limit is reached" - ), self.assertRaises(RequestException): + with ( + self.subTest("Test update when max retry limit is reached"), + self.assertRaises(RequestException), + ): mock_requests.get.side_effect = [ # For node status self._get_mock_response(200, response=self._TEST_ZT_NODE_CONFIG) diff --git a/openwisp_controller/config/utils.py b/openwisp_controller/config/utils.py index c0048fa0f..64e6edd6a 100644 --- a/openwisp_controller/config/utils.py +++ b/openwisp_controller/config/utils.py @@ -1,10 +1,14 @@ import logging +from time import sleep +from django.core.cache import cache from django.core.exceptions import ValidationError from django.db.models import Q from django.http import Http404, HttpResponse from django.shortcuts import get_object_or_404 as base_get_object_or_404 from django.urls import path, re_path +from django.utils.translation import gettext as _ +from openwisp_notifications.signals import notify from openwisp_notifications.utils import _get_object_link logger = logging.getLogger(__name__) @@ -206,3 +210,60 @@ def get_default_templates_queryset( def get_config_error_notification_target_url(obj, field, absolute_url=True): url = _get_object_link(obj._related_object(field), absolute_url) return f"{url}#config-group" + + +def send_api_task_notification(type, **kwargs): + vpn = kwargs.get("instance") + action = kwargs.get("action").replace("_", " ") + status_code = kwargs.get("status_code") + # Adding some delay here to prevent overlapping + # of the django success message container + # with the ow-notification container + # https://github.com/openwisp/openwisp-notifications/issues/264 + sleep(2) + message_map = { + "error": { + "verb": _("encountered an unrecoverable error"), + "message": _( + "Unable to perform {action} operation on the " + "{target} VPN server due to an " + "unrecoverable error " + "(status code: {status_code})" + ), + "level": "error", + }, + "recovery": { + "verb": _("has been completed successfully"), + "message": _("The {action} operation on {target} {verb}."), + "level": "info", + }, + } + meta = message_map[type] + notify.send( + type="generic_message", + sender=vpn, + target=vpn, + action=action, + verb=meta["verb"], + message=meta["message"].format( + action=action, + target=str(vpn), + status_code=status_code, + verb=meta["verb"], + ), + level=meta["level"], + ) + + +def handle_recovery_notification(task_key, **kwargs): + task_result = cache.get(task_key) + if task_result == "error": + send_api_task_notification("recovery", **kwargs) + cache.set(task_key, "success", None) + + +def handle_error_notification(task_key, response, **kwargs): + task_result = cache.get(task_key) + if task_result in (None, "success"): + cache.set(task_key, "error", None) + send_api_task_notification("error", status_code=response.status_code, **kwargs)