diff --git a/scanpipe/policies.py b/scanpipe/policies.py index d0ea94e5c3..f1d817112c 100644 --- a/scanpipe/policies.py +++ b/scanpipe/policies.py @@ -35,8 +35,8 @@ def load_policies_yaml(policies_yaml): def load_policies_file(policies_file, validate=True): """ - Load provided ``policies_file`` into a Python dictionary. - The policies format is validated by default. + Load provided ``policies_file`` into a Python dictionary. The policies format + is validated by default to ensure at least one policy type exists. """ policies_dict = load_policies_yaml(policies_yaml=policies_file.read_text()) if validate: @@ -45,13 +45,23 @@ def load_policies_file(policies_file, validate=True): def validate_policies(policies_dict): - """Return True if the provided ``policies_dict`` is valid.""" + """ + Return True if the provided ``policies_dict`` contains at least + one supported policy type. + """ if not isinstance(policies_dict, dict): raise ValidationError("The `policies_dict` argument must be a dictionary.") - if "license_policies" not in policies_dict: + supported_keys = { + "license_policies", + "license_clarity_thresholds", + "scorecard_score_thresholds", + } + + if not any(key in policies_dict for key in supported_keys): raise ValidationError( - "The `license_policies` key is missing from provided policies data." + "At least one of the following policy types must be present: " + f"{', '.join(sorted(supported_keys))}" ) return True diff --git a/scanpipe/tests/test_forms.py b/scanpipe/tests/test_forms.py index 98e0f83ff4..0dc5bc9077 100644 --- a/scanpipe/tests/test_forms.py +++ b/scanpipe/tests/test_forms.py @@ -219,7 +219,9 @@ def test_scanpipe_forms_project_settings_form_policies(self): self.assertFalse(form.is_valid()) expected = { "policies": [ - "The `license_policies` key is missing from provided policies data." + "At least one of the following policy types must be present: " + "license_clarity_thresholds, license_policies, " + "scorecard_score_thresholds" ] } self.assertEqual(expected, form.errors) diff --git a/scanpipe/tests/test_policies.py b/scanpipe/tests/test_policies.py index fc7be2f8d3..ba7800a8a0 100644 --- a/scanpipe/tests/test_policies.py +++ b/scanpipe/tests/test_policies.py @@ -66,7 +66,11 @@ def test_scanpipe_policies_validate_policies(self): with self.assertRaisesMessage(ValidationError, error_msg): validate_policies(policies_dict) - error_msg = "The `license_policies` key is missing from provided policies data." + error_msg = ( + "At least one of the following policy types must be present: " + "license_clarity_thresholds, license_policies, " + "scorecard_score_thresholds" + ) policies_dict = {} with self.assertRaisesMessage(ValidationError, error_msg): validate_policies(policies_dict)