diff --git a/invenio_collections/api.py b/invenio_collections/api.py index aaa7f58..e80e07a 100644 --- a/invenio_collections/api.py +++ b/invenio_collections/api.py @@ -5,11 +5,11 @@ # Invenio-Collections is free software; you can redistribute it and/or modify # it under the terms of the MIT License; see LICENSE file for more details. """Collections programmatic API.""" + from invenio_records.systemfields import ModelField -from luqum.parser import parser as luqum_parser from werkzeug.utils import cached_property -from .errors import CollectionNotFound, CollectionTreeNotFound, InvalidQuery +from .errors import CollectionNotFound, CollectionTreeNotFound from .models import Collection as CollectionModel from .models import CollectionTree as CollectionTreeModel @@ -34,14 +34,6 @@ def __init__(self, model=None, max_depth=2): self.model = model self.max_depth = max_depth - @classmethod - def validate_query(cls, query): - """Validate the collection query.""" - try: - luqum_parser.parse(query) - except Exception: - raise InvalidQuery() - @classmethod def create(cls, slug, title, query, ctree=None, parent=None, order=None, depth=2): """Create a new collection.""" @@ -55,7 +47,6 @@ def create(cls, slug, title, query, ctree=None, parent=None, order=None, depth=2 else: raise ValueError("Either parent or ctree must be set.") - Collection.validate_query(query) return cls( cls.model_cls.create( slug=slug, @@ -100,8 +91,6 @@ def read_all(cls, depth=2): def update(self, **kwargs): """Update the collection.""" - if "search_query" in kwargs: - Collection.validate_query(kwargs["search_query"]) self.model.update(**kwargs) return self diff --git a/invenio_collections/services/schema.py b/invenio_collections/services/schema.py index 8d3e604..13ca3d3 100644 --- a/invenio_collections/services/schema.py +++ b/invenio_collections/services/schema.py @@ -1,21 +1,87 @@ # -*- coding: utf-8 -*- # # Copyright (C) 2024 CERN. 
+# Copyright (C) 2025 Ubiquity Press # # Invenio-Collections is free software; you can redistribute it and/or modify # it under the terms of the MIT License; see LICENSE file for more details. -"""Collections schema.""" -from marshmallow import Schema, fields +"""Collections schemas.""" + +import re + +from invenio_i18n import lazy_gettext as _ +from luqum.exceptions import ParseError +from luqum.parser import parser as luqum_parser +from marshmallow import Schema, ValidationError, fields, validate +from marshmallow_utils.fields import SanitizedUnicode + + +def _not_blank(**kwargs): + """Return a Length validator that rejects blank or over-long values.""" + return validate.Length( + # Pass the constant msgid to gettext: formatting it first would make the + # catalog lookup miss, so the message could never be translated. + # marshmallow interpolates ``{max}`` itself when it renders the error. + error=_("Field cannot be blank or longer than {max} characters."), + min=1, + **kwargs, + ) + + +class CollectionTreeSchema(Schema): + """Collection tree schema.""" + + slug = SanitizedUnicode( + required=True, + validate=[ + _not_blank(max=255), + validate.Regexp( + r"^[-\w]+$", + flags=re.ASCII, + error=_( + "The identifier should contain only letters, numbers, or dashes." + ), + ), + ], + ) + title = SanitizedUnicode( + validate=[_not_blank(max=255)], + ) + order = fields.Int() + id = fields.Int(dump_only=True) + community_id = fields.Str(dump_only=True) + + +def validate_search_query(query): + """Validate a search query using luqum parser.""" + try: + luqum_parser.parse(query) + except ParseError as e: + raise ValidationError(str(e)) from e class CollectionSchema(Schema): """Collection schema.""" - slug = fields.Str() - title = fields.Str() + slug = SanitizedUnicode( + validate=[ + _not_blank(max=255), + validate.Regexp( + r"^[-\w]+$", + flags=re.ASCII, + error=_( + "The identifier should contain only letters, numbers, or dashes." 
+ ), + ), + ], + ) + title = SanitizedUnicode( + validate=[_not_blank(max=255)], + ) + depth = fields.Int(dump_only=True) order = fields.Int() id = fields.Int(dump_only=True) num_records = fields.Int() - search_query = fields.Str(load_only=True) + search_query = fields.Str(validate=[validate_search_query]) diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..516637f --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2025 Ubiquity Press +# +# Invenio-Collections is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +# +"""Test suite for the collections schemas.""" + +import pytest +from marshmallow import ValidationError + +from invenio_collections.services.schema import CollectionSchema + + +def test_collection_schema_validation(): + """Test search query validation.""" + valid_input = { + "slug": "col", + "title": "Test collection", + "order": 0, + "search_query": "*:*", + } + + schema = CollectionSchema() + collection = schema.load(valid_input) + assert valid_input == collection == schema.dump(collection) + + +def test_collection_schema_fail(): + """Test schema validation errors.""" + input = { + "slug": "col", + "title": "Test collection", + "order": 0, + "search_query": "*:*", + } + schema = CollectionSchema() + with pytest.raises(ValidationError) as exc_info: + input["search_query"] = "custom_fields.journal:journal.volume:'2025'" + schema.load(input) + assert exc_info.value.args[0] == { + "search_query": ["Illegal character ''2025'' at position 37"] + } + + # Set back query + input["search_query"] = "*:*" + + with pytest.raises(ValidationError) as exc_info: + input["slug"] = "not valid" + schema.load(input) + assert exc_info.value.args[0] == { + "slug": ["The identifier should contain only letters, numbers, or dashes."] + } diff --git a/tests/test_service.py b/tests/test_service.py index 
b673a28..dbd50fd 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -173,6 +173,7 @@ def test_collections_results( "order": c0.order, "slug": "collection-1", "title": "Collection 1", + "search_query": "metadata.title:foo", }, c1.id: { "children": [], @@ -186,6 +187,7 @@ def test_collections_results( "order": c1.order, "slug": "collection-2", "title": "Collection 2", + "search_query": "metadata.title:bar", }, } assert not list(dictdiffer.diff(expected, r_dict)) @@ -218,6 +220,7 @@ def test_collections_results( "order": c0.order, "slug": "collection-1", "title": "Collection 1", + "search_query": "metadata.title:foo", }, c1.id: { "children": [c3.id], @@ -231,6 +234,7 @@ def test_collections_results( "order": c1.order, "slug": "collection-2", "title": "Collection 2", + "search_query": "metadata.title:bar", }, c3.id: { "children": [], @@ -244,6 +248,7 @@ def test_collections_results( "order": c3.order, "slug": "collection-3", "title": "Collection 3", + "search_query": "metadata.title:baz", }, } @@ -259,7 +264,7 @@ def test_update(app, db, add_collections, collections_service, community_owner): collections_service.update( community_owner.identity, c0.id, - data={"slug": "New slug"}, + data={"slug": "new-slug"}, ) res = collections_service.read( @@ -267,20 +272,20 @@ def test_update(app, db, add_collections, collections_service, community_owner): id_=c0.id, ) - assert res.to_dict()[c0.id]["slug"] == "New slug" + assert res.to_dict()[c0.id]["slug"] == "new-slug" # Update by object collections_service.update( community_owner.identity, c0, - data={"slug": "New slug 2"}, + data={"slug": "new-slug-2"}, ) res = collections_service.read( identity=community_owner.identity, id_=c0.id, ) - assert res.to_dict()[c0.id]["slug"] == "New slug 2" + assert res.to_dict()[c0.id]["slug"] == "new-slug-2" def test_read_many(app, db, add_collections, collections_service, community_owner):