diff --git a/source/app/blueprints/graphql/cases.py b/source/app/blueprints/graphql/cases.py index a577936da..fad98020e 100644 --- a/source/app/blueprints/graphql/cases.py +++ b/source/app/blueprints/graphql/cases.py @@ -103,7 +103,7 @@ def mutate(root, info, name, description, client_id, soc_id=None, classification request['case_soc_id'] = soc_id if classification_id: request['classification_id'] = classification_id - case, _ = cases_create(request) + case = cases_create(request) return CaseCreate(case=case) diff --git a/source/app/blueprints/rest/manage/manage_cases_routes.py b/source/app/blueprints/rest/manage/manage_cases_routes.py index 07167dbf9..a6871c77f 100644 --- a/source/app/blueprints/rest/manage/manage_cases_routes.py +++ b/source/app/blueprints/rest/manage/manage_cases_routes.py @@ -249,8 +249,8 @@ def api_add_case(): case_schema = CaseSchema() try: - case, msg = cases_create(request.get_json()) - return response_success(msg, data=case_schema.dump(case)) + case = cases_create(request.get_json()) + return response_success('Case created', data=case_schema.dump(case)) except BusinessProcessingError as e: return response_error(e.get_message(), data=e.get_data()) @@ -264,6 +264,7 @@ def api_list_case(): @manage_cases_rest_blueprint.route('/manage/cases/update/', methods=['POST']) +@endpoint_deprecated('PUT', '/api/v2/cases/') @ac_api_requires(Permissions.standard_user) def update_case_info(cur_id): if not ac_fast_check_current_user_has_case_access(cur_id, [CaseAccessLevel.full_access]): diff --git a/source/app/blueprints/rest/v2/__init__.py b/source/app/blueprints/rest/v2/__init__.py index cb0fcf014..8d287f83c 100644 --- a/source/app/blueprints/rest/v2/__init__.py +++ b/source/app/blueprints/rest/v2/__init__.py @@ -10,14 +10,14 @@ # Create root /api/v2 blueprint -rest_v2_bp = Blueprint("rest_v2", __name__, url_prefix="/api/v2") +rest_v2_blueprint = Blueprint("rest_v2", __name__, url_prefix="/api/v2") # Register child blueprints -rest_v2_bp.register_blueprint(cases_blueprint) -rest_v2_bp.register_blueprint(auth_blueprint) -rest_v2_bp.register_blueprint(tasks_blueprint) -rest_v2_bp.register_blueprint(iocs_blueprint) -rest_v2_bp.register_blueprint(assets_blueprint) -rest_v2_bp.register_blueprint(alerts_blueprint) -rest_v2_bp.register_blueprint(dashboard_blueprint) +rest_v2_blueprint.register_blueprint(cases_blueprint) +rest_v2_blueprint.register_blueprint(auth_blueprint) +rest_v2_blueprint.register_blueprint(tasks_blueprint) +rest_v2_blueprint.register_blueprint(iocs_blueprint) +rest_v2_blueprint.register_blueprint(assets_blueprint) +rest_v2_blueprint.register_blueprint(alerts_blueprint) +rest_v2_blueprint.register_blueprint(dashboard_blueprint) diff --git a/source/app/blueprints/rest/v2/auth/__init__.py b/source/app/blueprints/rest/v2/auth/__init__.py index 907b27a42..f70e8fbfe 100644 --- a/source/app/blueprints/rest/v2/auth/__init__.py +++ b/source/app/blueprints/rest/v2/auth/__init__.py @@ -26,13 +26,12 @@ from app import db from app import oidc_client from app.blueprints.access_controls import is_authentication_ldap -from app.blueprints.access_controls import is_authentication_oidc, \ - not_authenticated_redirection_url +from app.blueprints.access_controls import is_authentication_oidc +from app.blueprints.access_controls import not_authenticated_redirection_url from app.blueprints.rest.endpoints import response_api_error from app.blueprints.rest.endpoints import response_api_success from app.business.auth import validate_ldap_login, validate_local_login from app.iris_engine.utils.tracker import track_activity -from app.models.authorization import User from app.schema.marshables import UserSchema diff --git a/source/app/blueprints/rest/v2/cases/__init__.py b/source/app/blueprints/rest/v2/cases/__init__.py index a802d7b20..d0a970fb5 100644 --- a/source/app/blueprints/rest/v2/cases/__init__.py +++ b/source/app/blueprints/rest/v2/cases/__init__.py @@ -34,6 +34,7 @@ from app.business.cases import cases_create from app.business.cases import cases_delete from app.datamgmt.case.case_db import get_case +from app.business.cases import cases_update from app.business.errors import BusinessProcessingError from app.datamgmt.manage.manage_cases_db import get_filtered_cases from app.schema.marshables import CaseSchemaForAPIV2 @@ -54,7 +55,7 @@ # Routes -@cases_blueprint.post('', strict_slashes=False) +@cases_blueprint.post('') @ac_api_requires(Permissions.standard_user) def create_case(): """ @@ -62,13 +63,13 @@ def create_case(): """ try: - case, _ = cases_create(request.get_json()) + case = cases_create(request.get_json()) return response_api_created(CaseSchemaForAPIV2().dump(case)) except BusinessProcessingError as e: return response_api_error(e.get_message(), e.get_data()) -@cases_blueprint.get('', strict_slashes=False) +@cases_blueprint.get('') @ac_api_requires() def get_cases() -> Response: """ @@ -123,7 +124,6 @@ def get_cases() -> Response: cases = { 'total': filtered_cases.total, - # TODO should maybe really uniform all return types of paginated list and replace field cases by field data 'data': CaseSchemaForAPIV2().dump(filtered_cases.items, many=True), 'last_page': filtered_cases.pages, 'current_page': filtered_cases.page, @@ -148,6 +148,19 @@ def case_routes_get(identifier): return response_api_success(CaseSchemaForAPIV2().dump(case)) +@cases_blueprint.put('/') +@ac_api_requires(Permissions.standard_user) +def rest_v2_cases_update(identifier): + if not ac_fast_check_current_user_has_case_access(identifier, [CaseAccessLevel.full_access]): + return ac_api_return_access_denied(caseid=identifier) + + try: + case, _ = cases_update(identifier, request.get_json()) + return response_api_success(CaseSchemaForAPIV2().dump(case)) + except BusinessProcessingError as e: + return response_api_error(e.get_message()) + + @cases_blueprint.delete('/') @ac_api_requires(Permissions.standard_user) def case_routes_delete(identifier): diff --git a/source/app/blueprints/rest/v2/cases/assets.py b/source/app/blueprints/rest/v2/cases/assets.py index 35f70a7be..8efa830d3 100644 --- a/source/app/blueprints/rest/v2/cases/assets.py +++ b/source/app/blueprints/rest/v2/cases/assets.py @@ -38,7 +38,7 @@ url_prefix='//assets') -@case_assets_blueprint.get('', strict_slashes=False) +@case_assets_blueprint.get('') @ac_api_requires() def case_list_assets(case_id): """ @@ -65,7 +65,7 @@ def case_list_assets(case_id): return response_api_error(e.get_message()) -@case_assets_blueprint.post('', strict_slashes=False) +@case_assets_blueprint.post('') @ac_api_requires() def add_asset(case_id): """ diff --git a/source/app/blueprints/rest/v2/cases/iocs.py b/source/app/blueprints/rest/v2/cases/iocs.py index ecb215c7e..741634d10 100644 --- a/source/app/blueprints/rest/v2/cases/iocs.py +++ b/source/app/blueprints/rest/v2/cases/iocs.py @@ -37,7 +37,7 @@ url_prefix='//iocs') -@case_iocs_blueprint.get('', strict_slashes=False) +@case_iocs_blueprint.get('') @ac_api_requires() def get_case_iocs(case_id): """ @@ -92,7 +92,7 @@ def get_case_iocs(case_id): return response_api_success(data=iocs) -@case_iocs_blueprint.post('', strict_slashes=False) +@case_iocs_blueprint.post('') @ac_api_requires() def add_ioc_to_case(case_id): """ diff --git a/source/app/blueprints/rest/v2/cases/tasks.py b/source/app/blueprints/rest/v2/cases/tasks.py index 7b2f34a15..2d17e842f 100644 --- a/source/app/blueprints/rest/v2/cases/tasks.py +++ b/source/app/blueprints/rest/v2/cases/tasks.py @@ -19,35 +19,41 @@ from flask import Blueprint from flask import request -from app.blueprints.rest.endpoints import response_api_error, response_api_not_found, response_api_deleted +from app.blueprints.rest.endpoints import response_api_error +from app.blueprints.rest.endpoints import response_api_not_found +from app.blueprints.rest.endpoints import response_api_deleted +from app.blueprints.rest.endpoints import response_api_success from app.blueprints.rest.endpoints import response_api_created from app.blueprints.access_controls import ac_api_return_access_denied from app.blueprints.access_controls import ac_api_requires from app.schema.marshables import CaseTaskSchema -from app.business.errors import BusinessProcessingError, ObjectNotFoundError -from app.business.tasks import tasks_create, tasks_get, tasks_delete +from app.business.errors import BusinessProcessingError +from app.business.errors import ObjectNotFoundError +from app.business.tasks import tasks_create +from app.business.tasks import tasks_get +from app.business.tasks import tasks_delete from app.models.authorization import CaseAccessLevel from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access case_tasks_blueprint = Blueprint('case_tasks', __name__, - url_prefix='//tasks') + url_prefix='//tasks') -@case_tasks_blueprint.post('', strict_slashes=False) +@case_tasks_blueprint.post('') @ac_api_requires() -def add_case_task(case_id): +def add_case_task(case_identifier): """ Add a task to a case. Args: - case_id (int): The Case ID for this task + case_identifier (int): The Case ID for this task """ - if not ac_fast_check_current_user_has_case_access(case_id, [CaseAccessLevel.full_access]): - return ac_api_return_access_denied(caseid=case_id) + if not ac_fast_check_current_user_has_case_access(case_identifier, [CaseAccessLevel.full_access]): + return ac_api_return_access_denied(caseid=case_identifier) task_schema = CaseTaskSchema() try: - _, case = tasks_create(case_id, request.get_json()) + _, case = tasks_create(case_identifier, request.get_json()) return response_api_created(task_schema.dump(case)) except BusinessProcessingError as e: return response_api_error(e.get_message()) @@ -55,45 +61,45 @@ def add_case_task(case_id): @case_tasks_blueprint.get('/') @ac_api_requires() -def get_case_task(case_id, identifier): +def get_case_task(case_identifier, identifier): """ Handles getting a task from a case. Args: - case_id (int): The case ID + case_identifier (int): The case ID identifier (int): The task ID """ try: task = tasks_get(identifier) - if task.task_case_id != case_id: + if task.task_case_id != case_identifier: raise ObjectNotFoundError() if not ac_fast_check_current_user_has_case_access(task.task_case_id, [CaseAccessLevel.read_only, CaseAccessLevel.full_access]): return ac_api_return_access_denied(caseid=task.task_case_id) task_schema = CaseTaskSchema() - return response_api_created(task_schema.dump(task)) + return response_api_success(task_schema.dump(task)) except ObjectNotFoundError: return response_api_not_found() @case_tasks_blueprint.delete('/') @ac_api_requires() -def delete_case_task(case_id, identifier): +def delete_case_task(case_identifier, identifier): """ Handle deleting a task from a case Args: - case_id (int): The case ID + case_identifier (int): The case ID identifier (int): The task ID """ try: task = tasks_get(identifier) - if task.task_case_id != case_id: + if task.task_case_id != case_identifier: raise ObjectNotFoundError() if not ac_fast_check_current_user_has_case_access(task.task_case_id, [CaseAccessLevel.full_access]): @@ -107,4 +113,4 @@ def delete_case_task(case_id, identifier): return response_api_error(e.get_message()) -# TODO: Add task endpoint endpoint \ No newline at end of file +# TODO: Add task update endpoint \ No newline at end of file diff --git a/source/app/blueprints/rest/v2/tasks.py b/source/app/blueprints/rest/v2/tasks.py index 4305276f7..04cddf013 100644 --- a/source/app/blueprints/rest/v2/tasks.py +++ b/source/app/blueprints/rest/v2/tasks.py @@ -19,7 +19,7 @@ from flask import Blueprint from app.blueprints.rest.endpoints import response_api_not_found -from app.blueprints.rest.endpoints import response_api_created +from app.blueprints.rest.endpoints import response_api_success from app.blueprints.rest.endpoints import response_api_deleted from app.blueprints.rest.endpoints import response_api_error from app.blueprints.access_controls import ac_api_requires @@ -49,8 +49,7 @@ def get_case_task(identifier): return ac_api_return_access_denied(caseid=task.task_case_id) task_schema = CaseTaskSchema() - # TODO should be response_api_success => add a test - return response_api_created(task_schema.dump(task)) + return response_api_success(task_schema.dump(task)) except ObjectNotFoundError: return response_api_not_found() diff --git a/source/app/business/cases.py b/source/app/business/cases.py index d9b115596..8ea4b47b4 100644 --- a/source/app/business/cases.py +++ b/source/app/business/cases.py @@ -78,9 +78,9 @@ def cases_exists(identifier): return case_db_exists(identifier) -def cases_create(request_json): +def cases_create(request_data): # TODO remove caseid doesn't seems to be useful for call_modules_hook => remove argument - request_data = call_modules_hook('on_preload_case_create', request_json, None) + request_data = call_modules_hook('on_preload_case_create', request_data, None) case = _load(request_data) @@ -115,7 +115,7 @@ def cases_create(request_json): add_obj_history_entry(case, 'created') track_activity(f'new case "{case.name}" created', caseid=case.case_id, ctx_less=False) - return case, 'Case created' + return case def cases_delete(case_identifier): diff --git a/source/app/views.py b/source/app/views.py index 10926f909..42744167e 100644 --- a/source/app/views.py +++ b/source/app/views.py @@ -98,7 +98,7 @@ from app.blueprints.rest.search_routes import search_rest_blueprint from app.blueprints.graphql.graphql_route import graphql_blueprint -from app.blueprints.rest.v2 import rest_v2_bp +from app.blueprints.rest.v2 import rest_v2_blueprint from app.models.authorization import User from app.post_init import run_post_init @@ -184,7 +184,7 @@ app.register_blueprint(rest_api_blueprint) app.register_blueprint(demo_blueprint) -app.register_blueprint(rest_v2_bp) +app.register_blueprint(rest_v2_blueprint) try: diff --git a/tests/tests_rest_cases.py b/tests/tests_rest_cases.py index 0126ea8a7..559c28aaa 100644 --- a/tests/tests_rest_cases.py +++ b/tests/tests_rest_cases.py @@ -17,6 +17,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. from unittest import TestCase +from uuid import uuid4 from iris import Iris @@ -49,6 +50,15 @@ def test_create_case_should_return_201(self): }) self.assertEqual(201, response.status_code) + def test_create_case_with_spurious_slash_should_return_404(self): + response = self._subject.create('/api/v2/cases/', { + 'case_name': 'name', + 'case_description': 'description', + 'case_customer': 1, + 'case_soc_id': '' + }) + self.assertEqual(404, response.status_code) + def test_create_case_with_missing_name_should_return_400(self): response = self._subject.create('/api/v2/cases', { 'case_description': 'description', @@ -160,6 +170,11 @@ def test_get_case_should_have_field_case_name(self): response = self._subject.get(f'/api/v2/cases/{case_identifier}').json() self.assertIn('case_name', response) + def test_get_case_should_have_field_case_customer_id(self): + case_identifier = self._subject.create_dummy_case() + response = self._subject.get(f'/api/v2/cases/{case_identifier}').json() + self.assertIn('case_customer_id', response) + def test_create_case_should_return_data_with_case_customer_when_case_customer_is_an_empty_string(self): body = { 'case_name': 'case name', @@ -169,3 +184,53 @@ def test_create_case_should_return_data_with_case_customer_when_case_customer_is } response = self._subject.create('/api/v2/cases', body).json() self.assertIn('case_customer', response['data']) + + def test_update_case_should_not_fail(self): + identifier = self._subject.create_dummy_case() + response = self._subject.update(f'/api/v2/cases/{identifier}', { 'case_name': 'new name' }) + self.assertEqual(200, response.status_code) + + def test_update_case_should_allow_to_update_severity(self): + identifier = self._subject.create_dummy_case() + response = self._subject.update(f'/api/v2/cases/{identifier}', { 'severity_id': 5 }).json() + self.assertEqual(5, response['severity_id']) + + def test_update_case_should_allow_to_update_classification(self): + identifier = self._subject.create_dummy_case() + response = self._subject.update(f'/api/v2/cases/{identifier}', { 'classification_id': 3 }).json() + self.assertEqual(3, response['classification_id']) + + def test_update_case_should_allow_to_update_owner(self): + user = self._subject.create_dummy_user() + identifier = self._subject.create_dummy_case() + response = self._subject.update(f'/api/v2/cases/{identifier}', { 'owner_id': user.get_identifier() }).json() + self.assertEqual(user.get_identifier(), response['owner']['id']) + + def test_update_case_should_allow_to_update_state(self): + identifier = self._subject.create_dummy_case() + response = self._subject.update(f'/api/v2/cases/{identifier}', { 'state_id': 2 }).json() + self.assertEqual(2, response['state']['state_id']) + + def test_update_case_should_allow_to_update_status(self): + identifier = self._subject.create_dummy_case() + response = self._subject.update(f'/api/v2/cases/{identifier}', { 'status_id': 2 }).json() + self.assertEqual(2, response['status_id']) + + def test_update_case_should_allow_to_update_customer(self): + identifier = self._subject.create_dummy_case() + response = self._subject.create('/manage/customers/add', { 'customer_name': f'customer{uuid4()}'}).json() + customer_identifier = response['data']['customer_id'] + response = self._subject.update(f'/api/v2/cases/{identifier}', {'case_customer': customer_identifier}).json() + self.assertEqual(customer_identifier, response['case_customer_id']) + + def test_update_case_should_allow_to_update_reviewer(self): + identifier = self._subject.create_dummy_case() + user = self._subject.create_dummy_user() + response = self._subject.update(f'/api/v2/cases/{identifier}', {'reviewer_id': user.get_identifier()}).json() + self.assertEqual(user.get_identifier(), response['reviewer_id']) + + def test_update_case_should_allow_to_update_tags(self): + identifier = self._subject.create_dummy_case() + response = self._subject.update(f'/api/v2/cases/{identifier}', {'case_tags': 'tag1,tag2'}).json() + self.assertEqual('tag1,tag2', response['case_tags']) + diff --git a/tests/tests_rest_tasks.py b/tests/tests_rest_tasks.py index 04da681d5..ce3c7ac7b 100644 --- a/tests/tests_rest_tasks.py +++ b/tests/tests_rest_tasks.py @@ -44,7 +44,14 @@ def test_add_task_with_missing_task_title_identifier_should_return_400(self): response = self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body) self.assertEqual(400, response.status_code) - def test_get_task_should_return_201(self): + def test_create_case_with_spurious_slash_should_return_404(self): + case_identifier = self._subject.create_dummy_case() + body = {'task_assignees_id': [1], 'task_description': '', 'task_status_id': 1, 'task_tags': '', + 'task_title': 'dummy title', 'custom_attributes': {}} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/tasks/', body) + self.assertEqual(404, response.status_code) + + def test_get_task_should_return_200(self): case_identifier = self._subject.create_dummy_case() body = {'task_assignees_id': [2], 'task_description': '', 'task_status_id': 1, 'task_tags': '', 'task_title': 'dummy title', @@ -52,7 +59,7 @@ def test_get_task_should_return_201(self): response = self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() task_identifier = response['id'] response = self._subject.get(f'/api/v2/cases/{case_identifier}/tasks/{task_identifier}') - self.assertEqual(201, response.status_code) + self.assertEqual(200, response.status_code) def test_get_task_with_missing_task_identifier_should_return_error(self): case_identifier = self._subject.create_dummy_case()