diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4f777e6444..35b606beff 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,12 @@ Changelog ========= +v35.4.2 (unreleased) +-------------------- + +- Add arguments support for the reset action in REST API. + https://github.com/aboutcode-org/scancode.io/issues/1948 + v35.4.1 (2025-10-24) -------------------- diff --git a/scanpipe/api/views.py b/scanpipe/api/views.py index b07ea8772f..8288b2858f 100644 --- a/scanpipe/api/views.py +++ b/scanpipe/api/views.py @@ -21,6 +21,7 @@ # Visit https://github.com/aboutcode-org/scancode.io for support and download. import json +import logging from django.apps import apps from django.core.exceptions import ObjectDoesNotExist @@ -58,9 +59,22 @@ from scanpipe.pipes.compliance import get_project_compliance_alerts from scanpipe.views import project_results_json_response +logger = logging.getLogger(__name__) scanpipe_app = apps.get_app_config("scanpipe") +class ErrorResponse(Response): + def __init__(self, message, status_code=status.HTTP_400_BAD_REQUEST, **kwargs): + # If message is already a dict, use it as-is + if isinstance(message, dict): + data = message + else: + # Otherwise, wrap string in {"status": message} + data = {"status": message} + + super().__init__(data=data, status=status_code, **kwargs) + + class ProjectFilterSet(django_filters.rest_framework.FilterSet): name = django_filters.CharFilter() name__contains = django_filters.CharFilter( @@ -176,8 +190,7 @@ def results_download(self, request, *args, **kwargs): elif format == "all_outputs": output_file = output.to_all_outputs(project) else: - message = {"status": f"Format {format} not supported."} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse(f"Format {format} not supported.") filename = output.safe_filename(f"scancodeio_{project.name}_{output_file.name}") return FileResponse( @@ -196,8 +209,7 @@ def summary(self, request, *args, **kwargs): summary_file = project.get_latest_output(filename="summary") if not summary_file: - message = {"error": "Summary file not available"} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse({"error": "Summary file not available"}) summary_json = json.loads(summary_file.read_text()) return Response(summary_json) @@ -224,14 +236,14 @@ def report(self, request, *args, **kwargs): ), "choices": ", ".join(model_choices), } - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse(message) if model not in model_choices: message = { "error": f"{model} is not on of the valid choices", "choices": ", ".join(model_choices), } - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse(message) output_file = output.get_xlsx_report( project_qs=project_qs, @@ -254,8 +266,7 @@ def get_filtered_response( """ filterset = filterset_class(data=request.GET, queryset=queryset) if not filterset.is_valid(): - message = {"errors": filterset.errors} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse({"errors": filterset.errors}) queryset = filterset.qs paginated_qs = self.paginate_queryset(queryset) @@ -313,14 +324,12 @@ def file_content(self, request, *args, **kwargs): try: codebase_resource = codebase_resources.get(path=path) except ObjectDoesNotExist: - message = {"status": "Resource not found. Use ?path="} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse("Resource not found. Use ?path=") try: file_content = codebase_resource.file_content except OSError: - message = {"status": "File not available"} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse("File not available") return Response({"file_content": file_content}) @@ -339,32 +348,29 @@ def add_pipeline(self, request, *args, **kwargs): {"status": "Pipeline added."}, status=status.HTTP_201_CREATED ) - message = {"status": f"{pipeline} is not a valid pipeline."} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse(f"{pipeline} is not a valid pipeline.") message = { "status": "Pipeline required.", "pipelines": list(scanpipe_app.pipelines.keys()), } - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse(message) @action(detail=True, methods=["get", "post"]) def add_input(self, request, *args, **kwargs): project = self.get_object() if not project.can_change_inputs: - message = { - "status": "Cannot add inputs once a pipeline has started to execute." - } - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse( + "Cannot add inputs once a pipeline has started to execute." + ) upload_file = request.data.get("upload_file") upload_file_tag = request.data.get("upload_file_tag", "") input_urls = request.data.get("input_urls", []) if not (upload_file or input_urls): - message = {"status": "upload_file or input_urls required."} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse("upload_file or input_urls required.") if upload_file: project.add_upload(upload_file, tag=upload_file_tag) @@ -396,13 +402,13 @@ def add_webhook(self, request, *args, **kwargs): ) # Return validation errors - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse(serializer.errors) def destroy(self, request, *args, **kwargs): try: return super().destroy(request, *args, **kwargs) - except RunInProgressError as error: - return Response({"status": str(error)}, status=status.HTTP_400_BAD_REQUEST) + except RunInProgressError: + return ErrorResponse("Cannot delete project while a run is in progress.") @action(detail=True, methods=["get", "post"]) def archive(self, request, *args, **kwargs): @@ -423,10 +429,10 @@ def archive(self, request, *args, **kwargs): remove_codebase=request.data.get("remove_codebase"), remove_output=request.data.get("remove_output"), ) - except RunInProgressError as error: - return Response(error, status=status.HTTP_400_BAD_REQUEST) - else: - return Response({"status": f"The project {project} has been archived."}) + except RunInProgressError: + return ErrorResponse("Cannot archive project while a run is in progress.") + + return Response({"status": f"The project {project} has been archived."}) @action(detail=True, methods=["get", "post"]) def reset(self, request, *args, **kwargs): @@ -437,13 +443,15 @@ def reset(self, request, *args, **kwargs): return Response({"status": message}) try: - project.reset(keep_input=True) - except RunInProgressError as error: - return Response(error, status=status.HTTP_400_BAD_REQUEST) - else: - message = ( - f"All data, except inputs, for the {project} project have been removed." + project.reset( + keep_input=request.data.get("keep_input", True), + restore_pipelines=request.data.get("restore_pipelines", False), + execute_now=request.data.get("execute_now", False), ) + except RunInProgressError: + return ErrorResponse("Cannot reset project while a run is in progress.") + else: + message = f"The {project} project has been reset." return Response({"status": message}) @action(detail=True, methods=["get"]) @@ -455,8 +463,7 @@ def outputs(self, request, *args, **kwargs): if file_path.exists(): return FileResponse(file_path.open("rb")) - message = {"status": f"Output file {filename} not found"} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse(f"Output file {filename} not found") action_url = self.reverse_action(self.outputs.url_name, args=[project.pk]) output_data = [ @@ -531,14 +538,11 @@ class RunViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet): def start_pipeline(self, request, *args, **kwargs): run = self.get_object() if run.task_end_date: - message = {"status": "Pipeline already executed."} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse("Pipeline already executed.") elif run.task_start_date: - message = {"status": "Pipeline already started."} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse("Pipeline already started.") elif run.task_id: - message = {"status": "Pipeline already queued."} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse("Pipeline already queued.") transaction.on_commit(run.start) @@ -549,8 +553,7 @@ def stop_pipeline(self, request, *args, **kwargs): run = self.get_object() if run.status != run.Status.RUNNING: - message = {"status": "Pipeline is not running."} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse("Pipeline is not running.") run.stop_task() return Response({"status": f"Pipeline {run.pipeline_name} stopped."}) @@ -560,8 +563,7 @@ def delete_pipeline(self, request, *args, **kwargs): run = self.get_object() if run.status not in [run.Status.NOT_STARTED, run.Status.QUEUED]: - message = {"status": "Only non started or queued pipelines can be deleted."} - return Response(message, status=status.HTTP_400_BAD_REQUEST) + return ErrorResponse("Only non started or queued pipelines can be deleted.") run.delete_task() return Response({"status": f"Pipeline {run.pipeline_name} deleted."}) diff --git a/scanpipe/tests/test_api.py b/scanpipe/tests/test_api.py index 8362e6a3ff..9894473757 100644 --- a/scanpipe/tests/test_api.py +++ b/scanpipe/tests/test_api.py @@ -909,10 +909,7 @@ def test_scanpipe_api_project_action_delete(self): response = self.csrf_client.delete(self.project1_detail_url) self.assertEqual(status.HTTP_400_BAD_REQUEST, response.status_code) - expected = ( - "Cannot execute this action until all associated pipeline runs are " - "completed." - ) + expected = "Cannot delete project while a run is in progress." self.assertEqual(expected, response.data["status"]) run.set_task_ended(exitcode=0) @@ -962,10 +959,7 @@ def test_scanpipe_api_project_action_reset(self): response = self.csrf_client.post(url) self.assertEqual(status.HTTP_200_OK, response.status_code) - expected = { - "status": "All data, except inputs, for the Analysis project have been " - "removed." - } + expected = {"status": "The Analysis project has been reset."} self.assertEqual(expected, response.data) self.assertEqual(0, self.project1.runs.count()) self.assertEqual(0, self.project1.codebaseresources.count())