Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions invenio_collections/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
# Invenio-Collections is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Collections programmatic API."""

from invenio_records.systemfields import ModelField
from luqum.parser import parser as luqum_parser
from werkzeug.utils import cached_property

from .errors import CollectionNotFound, CollectionTreeNotFound, InvalidQuery
from .errors import CollectionNotFound, CollectionTreeNotFound
from .models import Collection as CollectionModel
from .models import CollectionTree as CollectionTreeModel

Expand All @@ -34,14 +34,6 @@ def __init__(self, model=None, max_depth=2):
self.model = model
self.max_depth = max_depth

@classmethod
def validate_query(cls, query):
"""Validate the collection query."""
try:
luqum_parser.parse(query)
except Exception:
raise InvalidQuery()

@classmethod
def create(cls, slug, title, query, ctree=None, parent=None, order=None, depth=2):
"""Create a new collection."""
Expand All @@ -55,7 +47,6 @@ def create(cls, slug, title, query, ctree=None, parent=None, order=None, depth=2
else:
raise ValueError("Either parent or ctree must be set.")

Collection.validate_query(query)
return cls(
cls.model_cls.create(
slug=slug,
Expand Down Expand Up @@ -100,8 +91,6 @@ def read_all(cls, depth=2):

def update(self, **kwargs):
"""Update the collection."""
if "search_query" in kwargs:
Collection.validate_query(kwargs["search_query"])
self.model.update(**kwargs)
return self

Expand Down
76 changes: 71 additions & 5 deletions invenio_collections/services/schema.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,87 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
# Copyright (C) 2025 Ubiquity Press
#
# Invenio-Collections is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Collections schema."""

from marshmallow import Schema, fields
"""Collections schemas."""

import re

from invenio_i18n import lazy_gettext as _
from luqum.exceptions import ParseError
from luqum.parser import parser as luqum_parser
from marshmallow import Schema, ValidationError, fields, validate
from marshmallow_utils.fields import SanitizedUnicode


def _not_blank(**kwargs):
"""Returns a non-blank validation rule."""
max_ = kwargs.get("max", "")
return validate.Length(
error=_(
"Field cannot be blank or longer than {max_} characters.".format(max_=max_)
),
min=1,
**kwargs,
)


class CollectionTreeSchema(Schema):
"""Collection tree schema."""

slug = SanitizedUnicode(
required=True,
validate=[
_not_blank(max=255),
validate.Regexp(
r"^[-\w]+$",
flags=re.ASCII,
error=_(
"The identifier should contain only letters, numbers, or dashes."
),
),
],
)
title = SanitizedUnicode(
validate=[_not_blank(max=255)],
)
order = fields.Int()
id = fields.Int(dump_only=True)
community_id = fields.Str(dump_only=True)


def validate_search_query(query):
"""Validate a search query using luqum parser."""
try:
luqum_parser.parse(query)
except ParseError as e:
raise ValidationError(str(e)) from e


class CollectionSchema(Schema):
"""Collection schema."""

slug = fields.Str()
title = fields.Str()
slug = SanitizedUnicode(
Copy link
Member Author

@egabancho egabancho Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of these fields should be required, like slug, but that would be a breaking change (more than adding these validations) and require a different schema for update operations (example).
I am open to suggestions 😇

validate=[
_not_blank(max=255),
validate.Regexp(
r"^[-\w]+$",
flags=re.ASCII,
error=_(
"The identifier should contain only letters, numbers, or dashes."
),
),
],
)
title = SanitizedUnicode(
validate=[_not_blank(max=255)],
)

depth = fields.Int(dump_only=True)
order = fields.Int()
id = fields.Int(dump_only=True)
num_records = fields.Int()
search_query = fields.Str(load_only=True)
search_query = fields.Str(validate=[validate_search_query])
54 changes: 54 additions & 0 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2025 Ubiquity Press
#
# Invenio-RDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
#
"""Test suite for the collections schemas."""

import pytest
from marshmallow import ValidationError

from invenio_collections.services.schema import CollectionSchema


def test_collection_schema_validation():
"""Test search query validation."""
valid_input = {
"slug": "col",
"title": "Test collection",
"order": 0,
"search_query": "*:*",
}

schema = CollectionSchema()
collection = schema.load(valid_input)
assert valid_input == collection == schema.dump(collection)


def test_collection_schema_fail():
"""Test schema validation errors."""
input = {
"slug": "col",
"title": "Test collection",
"order": 0,
"search_query": "*:*",
}
schema = CollectionSchema()
with pytest.raises(ValidationError) as exc_info:
input["search_query"] = "custom_fields.journal:journal.volume:'2025'"
schema.load(input)
assert exc_info.value.args[0] == {
"search_query": ["Illegal character ''2025'' at position 37"]
}

# Set back query
input["search_query"] = "*:*"

with pytest.raises(ValidationError) as exc_info:
input["slug"] = "not valid"
schema.load(input)
assert exc_info.value.args[0] == {
"slug": ["The identifier should contain only letters, numbers, or dashes."]
}
13 changes: 9 additions & 4 deletions tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def test_collections_results(
"order": c0.order,
"slug": "collection-1",
"title": "Collection 1",
"search_query": "metadata.title:foo",
},
c1.id: {
"children": [],
Expand All @@ -186,6 +187,7 @@ def test_collections_results(
"order": c1.order,
"slug": "collection-2",
"title": "Collection 2",
"search_query": "metadata.title:bar",
},
}
assert not list(dictdiffer.diff(expected, r_dict))
Expand Down Expand Up @@ -218,6 +220,7 @@ def test_collections_results(
"order": c0.order,
"slug": "collection-1",
"title": "Collection 1",
"search_query": "metadata.title:foo",
},
c1.id: {
"children": [c3.id],
Expand All @@ -231,6 +234,7 @@ def test_collections_results(
"order": c1.order,
"slug": "collection-2",
"title": "Collection 2",
"search_query": "metadata.title:bar",
},
c3.id: {
"children": [],
Expand All @@ -244,6 +248,7 @@ def test_collections_results(
"order": c3.order,
"slug": "collection-3",
"title": "Collection 3",
"search_query": "metadata.title:baz",
},
}

Expand All @@ -259,28 +264,28 @@ def test_update(app, db, add_collections, collections_service, community_owner):
collections_service.update(
community_owner.identity,
c0.id,
data={"slug": "New slug"},
data={"slug": "new-slug"},
)

res = collections_service.read(
identity=community_owner.identity,
id_=c0.id,
)

assert res.to_dict()[c0.id]["slug"] == "New slug"
assert res.to_dict()[c0.id]["slug"] == "new-slug"

# Update by object
collections_service.update(
community_owner.identity,
c0,
data={"slug": "New slug 2"},
data={"slug": "new-slug-2"},
)

res = collections_service.read(
identity=community_owner.identity,
id_=c0.id,
)
assert res.to_dict()[c0.id]["slug"] == "New slug 2"
assert res.to_dict()[c0.id]["slug"] == "new-slug-2"


def test_read_many(app, db, add_collections, collections_service, community_owner):
Expand Down
Loading