diff --git a/api/scpca_portal/filter.py b/api/scpca_portal/filter.py new file mode 100644 index 000000000..1f85e38aa --- /dev/null +++ b/api/scpca_portal/filter.py @@ -0,0 +1,89 @@ +from django.contrib.postgres.fields import ArrayField +from django.db import models + +import django_filters +from django_filters import rest_framework as filters + +# Lookup expressions per field type +FILTER_LOOKUPS = { + models.BigIntegerField: ["exact", "gte", "lte", "gt", "lt", "in"], + models.BooleanField: ["exact"], + models.CharField: ["exact", "icontains", "istartswith"], + models.DateTimeField: ["exact", "gte", "lte", "date"], + models.EmailField: ["exact", "icontains", "istartswith"], + models.IntegerField: ["exact", "gte", "lte", "gt", "lt", "in"], + models.JSONField: ["exact", "in"], + models.PositiveIntegerField: ["exact", "gte", "lte", "gt", "lt", "in"], + models.TextField: ["exact", "icontains"], +} + + +# Custom filter for Postgres ArrayFields +class ArrayFieldContainsFilter(django_filters.BaseInFilter, django_filters.CharFilter): + """ + Accepts comma-separated values and applies icontains per term (AND logic). + e.g. ?diagnoses=Neuroblastoma,Glioma matches projects containing both. + NOTE: Swap the loop for Q objects if you want OR logic instead. + """ + + def filter(self, qs, value): + if not value: + return qs + for term in value: + qs = qs.filter(**{f"{self.field_name}__icontains": term.strip()}) + return qs + + +# Filterset Factory +def build_auto_filterset( + model, + auto_fields: list[str] = None, + extra_fields: dict[str, list[str]] = None, + extra_filters: dict = None, +): + """ + Introspects a model and builds a FilterSet with sensible lookup expressions + per field type. ArrayFields get icontains via ArrayFieldContainsFilter. + Args: + model: The Django model class to build a FilterSet for. + auto_fields: Optional allowlist of field names. If omitted, all + supported field types are included. Always use this + to keep your public API surface intentional. + extra_fields: Additional model fields included in the public API + e.g. {"project__scpca_id": ["exact"]}. + extra_filters: Optional dict of additional filter instances to mix in, + excluded from the public API + e.g. {"in_stock": MyCustomFilter(...)}. + """ + + declared_filters = {} + meta_fields = {} + + for field in model._meta.get_fields(): + if field.is_relation and (field.one_to_many or field.many_to_many): + # Skip reverse relations and ManyToMany + continue + if auto_fields and field.name not in auto_fields: + continue + + # ArrayField: use custom filter, one filter per field + if isinstance(field, ArrayField): + declared_filters[field.name] = ArrayFieldContainsFilter(field_name=field.name) + continue + + # Standard field types: use dict-style meta fields for multi-lookup support + for field_type, lookups in FILTER_LOOKUPS.items(): + if isinstance(field, field_type): + meta_fields[field.name] = lookups + break + + if extra_fields: + meta_fields.update(extra_fields) + + if extra_filters: + declared_filters.update(extra_filters) + + meta = type("Meta", (), {"model": model, "fields": meta_fields}) + attrs = {"Meta": meta, **declared_filters} + + return type(f"{model.__name__}AutoFilterSet", (filters.FilterSet,), attrs) diff --git a/api/scpca_portal/test/test_filter.py b/api/scpca_portal/test/test_filter.py new file mode 100644 index 000000000..fc0622b82 --- /dev/null +++ b/api/scpca_portal/test/test_filter.py @@ -0,0 +1,81 @@ +from django.test import TestCase + +from scpca_portal import filter +from scpca_portal.filter import ArrayFieldContainsFilter +from scpca_portal.models import Sample + + +class FilterTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.SampleFilterSet = filter.build_auto_filterset( + Sample, + auto_fields=[ + "scpca_id", # TextField + "has_cite_seq_data", # BooleanField + "technologies", # ArrayField + "sample_cell_count_estimate", # IntegerField + "updated_at", # DateTimeField + ], + extra_fields={"project__scpca_id": ["exact"]}, + ) + + def test_all_included_fields(self): + actual_fields = self.SampleFilterSet.base_filters.keys() + expected_fields = [ + "scpca_id", + "has_cite_seq_data", + "technologies", + "sample_cell_count_estimate", + "updated_at", + # extra_fields + "project__scpca_id", + ] + + for expected_field in expected_fields: + self.assertIn(expected_field, actual_fields) + + def test_array_fields(self): + # Should be an instance of ArrayFieldContainsFilter + array_field_filter = self.SampleFilterSet.base_filters["technologies"] + self.assertIsInstance(array_field_filter, ArrayFieldContainsFilter) + + def test_boolean_fields(self): + # Should support "exact" + actual_fields = self.SampleFilterSet.base_filters.keys() + expected_fields = ["has_cite_seq_data"] + + for expected_field in expected_fields: + self.assertIn(expected_field, actual_fields) + + def test_datetime_fields(self): + # Should support "exact", "gte", "lte", and "date" + actual_fields = self.SampleFilterSet.base_filters.keys() + expected_fields = ["updated_at", "updated_at__gte", "updated_at__lte", "updated_at__date"] + + for expected_field in expected_fields: + self.assertIn(expected_field, actual_fields) + + def test_integer_fields(self): + # Should support"exact", "gte", "lte", "gt", "lt", and "in" + actual_fields = self.SampleFilterSet.base_filters.keys() + + expected_fields = [ + "sample_cell_count_estimate", + "sample_cell_count_estimate__gte", + "sample_cell_count_estimate__lte", + "sample_cell_count_estimate__gt", + "sample_cell_count_estimate__lt", + "sample_cell_count_estimate__in", + ] + + for expected_field in expected_fields: + self.assertIn(expected_field, actual_fields) + + def test_text_fields(self): + # Should support "exact", "icontains" + actual_fields = self.SampleFilterSet.base_filters.keys() + expected_fields = ["scpca_id", "scpca_id__icontains"] + + for expected_field in expected_fields: + self.assertIn(expected_field, actual_fields) diff --git a/api/scpca_portal/views/project.py b/api/scpca_portal/views/project.py index 33d261e4c..f303a4637 100644 --- a/api/scpca_portal/views/project.py +++ b/api/scpca_portal/views/project.py @@ -1,50 +1,46 @@ -from django.contrib.postgres.fields import ArrayField from rest_framework import viewsets -from django_filters import rest_framework as filters from drf_spectacular.utils import extend_schema from rest_framework_extensions.mixins import NestedViewSetMixin +from scpca_portal import filter from scpca_portal.models import Project from scpca_portal.serializers import ProjectDetailSerializer, ProjectSerializer - -class ProjectFilterSet(filters.FilterSet): - """ - Custom FilterSet to support ArrayField. - """ - - class Meta: - model = Project - fields = [ - "scpca_id", - "pi_name", - "has_bulk_rna_seq", - "has_cite_seq_data", - "has_multiplexed_data", - "has_single_cell_data", - "has_spatial_data", - "includes_cell_lines", - "includes_xenografts", - "diagnoses", - "seq_units", - "modalities", - "organisms", - "technologies", - "disease_timings", - "human_readable_pi_name", - "title", - "abstract", - ] - - filter_overrides = { - ArrayField: { - "filter_class": filters.CharFilter, - "extra": lambda f: { - "lookup_expr": "icontains", - }, - } - } +ProjectFilterSet = filter.build_auto_filterset( + Project, + auto_fields=[ + "scpca_id", + "pi_name", + "has_bulk_rna_seq", + "has_cite_seq_data", + "has_multiplexed_data", + "has_single_cell_data", + "has_spatial_data", + "includes_anndata", + "includes_cell_lines", + "includes_merged_anndata", + "includes_merged_sce", + "includes_xenografts", + "diagnoses", + "seq_units", + "modalities", + "organisms", + "technologies", + "disease_timings", + "human_readable_pi_name", + "title", + "abstract", + # counts + "sample_count", + "downloadable_sample_count", + "multiplexed_sample_count", + "unavailable_samples_count", + # timestamps + "created_at", + "updated_at", + ], +) @extend_schema(auth=False) diff --git a/api/scpca_portal/views/sample.py b/api/scpca_portal/views/sample.py index e8843e669..e3f18327a 100644 --- a/api/scpca_portal/views/sample.py +++ b/api/scpca_portal/views/sample.py @@ -1,10 +1,9 @@ -from django.contrib.postgres.fields import ArrayField from rest_framework import viewsets -from django_filters import rest_framework as filters from drf_spectacular.utils import extend_schema from rest_framework_extensions.mixins import NestedViewSetMixin +from scpca_portal import filter from scpca_portal.models import Sample from scpca_portal.serializers import ComputedFileSerializer, ProjectSerializer, SampleSerializer @@ -14,42 +13,39 @@ class SampleDetailSerializer(SampleSerializer): project = ProjectSerializer(read_only=True) -class SampleFilterSet(filters.FilterSet): - """ - Custom FilterSet to support ArrayField. - """ - - class Meta: - model = Sample - fields = [ - "scpca_id", - "project__scpca_id", - "scpca_id", - "has_cite_seq_data", - "has_bulk_rna_seq", - "has_multiplexed_data", - "has_single_cell_data", - "has_spatial_data", - "technologies", - "diagnosis", - "subdiagnosis", - "age", - "age_timing", - "sex", - "disease_timing", - "tissue_location", - "treatment", - "seq_units", - ] - - filter_overrides = { - ArrayField: { - "filter_class": filters.CharFilter, - "extra": lambda f: { - "lookup_expr": "icontains", - }, - } - } +SampleFilterSet = filter.build_auto_filterset( + Sample, + auto_fields=[ + "scpca_id", + "has_cite_seq_data", + "has_bulk_rna_seq", + "has_multiplexed_data", + "has_single_cell_data", + "has_spatial_data", + "includes_anndata", + "is_cell_line", + "is_xenograft", + "technologies", + "diagnosis", + "subdiagnosis", + "age", + "age_timing", + "sex", + "disease_timing", + "tissue_location", + "treatment", + "seq_units", + # counts + "demux_cell_count_estimate_sum", + "sample_cell_count_estimate", + # timestamps + "created_at", + "updated_at", + ], + extra_fields={ + "project__scpca_id": ["exact"], + }, +) @extend_schema(auth=False)