diff --git a/ami/base/permissions.py b/ami/base/permissions.py index 5ba3a9c68..ffae6211a 100644 --- a/ami/base/permissions.py +++ b/ami/base/permissions.py @@ -66,7 +66,7 @@ def add_object_level_permissions( # Do not return create, view permissions at object-level filtered_permissions -= {"create", "view"} permissions.update(filtered_permissions) - response_data["user_permissions"] = permissions + response_data["user_permissions"] = list(permissions) return response_data @@ -86,7 +86,7 @@ def add_collection_level_permissions(user: User | None, response_data: dict, mod if user and project and f"create_{model.__name__.lower()}" in get_perms(user, project): permissions.add("create") - response_data["user_permissions"] = permissions + response_data["user_permissions"] = list(permissions) return response_data diff --git a/ami/exports/__init__.py b/ami/exports/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ami/exports/admin.py b/ami/exports/admin.py new file mode 100644 index 000000000..91dc29cb7 --- /dev/null +++ b/ami/exports/admin.py @@ -0,0 +1,64 @@ +from django.contrib import admin +from django.http import HttpRequest + +from .models import DataExport + + +@admin.register(DataExport) +class DataExportAdmin(admin.ModelAdmin): + """ + Admin panel for managing DataExport objects. + """ + + list_display = ("id", "user", "format", "status_display", "project", "created_at", "get_job") + list_filter = ("format", "project") + search_fields = ("user__username", "format", "project__name") + readonly_fields = ("status_display", "file_url_display") + + fieldsets = ( + ( + None, + { + "fields": ("user", "format", "project", "filters"), + }, + ), + ( + "Job Info", + { + "fields": ("status_display", "file_url_display"), + "classes": ("collapse",), # This makes job-related fields collapsible in the admin panel + }, + ), + ) + + def get_queryset(self, request: HttpRequest): + """ + Optimize queryset by selecting related project and job data. + """ + return super().get_queryset(request).select_related("project", "job") + + @admin.display(description="Status") + def status_display(self, obj): + return obj.status # Calls the @property from the model + + @admin.display(description="File URL") + def file_url_display(self, obj): + return obj.file_url # Calls the @property from the model + + @admin.display(description="Job ID") + def get_job(self, obj): + """Displays the related job ID or 'No Job' if none exists.""" + return obj.job.id if obj.job else "No Job" + + @admin.action(description="Run export job") + def run_export_job(self, request: HttpRequest, queryset): + """ + Admin action to trigger the export job manually. 
+ """ + for export in queryset: + if export.job: + export.job.enqueue() + + self.message_user(request, f"Started export job for {queryset.count()} export(s).") + + actions = [run_export_job] diff --git a/ami/exports/apps.py b/ami/exports/apps.py new file mode 100644 index 000000000..4023432d4 --- /dev/null +++ b/ami/exports/apps.py @@ -0,0 +1,9 @@ +from django.apps import AppConfig + + +class ExportsConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "ami.exports" + + def ready(self): + import ami.exports.signals # noqa: F401 diff --git a/ami/exports/base.py b/ami/exports/base.py new file mode 100644 index 000000000..389480d5e --- /dev/null +++ b/ami/exports/base.py @@ -0,0 +1,73 @@ +import logging +import os +from abc import ABC, abstractmethod + +from ami.exports.utils import apply_filters + +logger = logging.getLogger(__name__) + + +class BaseExporter(ABC): + """Base class for all data export handlers.""" + + file_format = "" # To be defined in child classes + serializer_class = None + filter_backends = [] + + def __init__(self, data_export): + self.data_export = data_export + self.job = data_export.job if hasattr(data_export, "job") else None + self.project = data_export.project + self.queryset = apply_filters( + queryset=self.get_queryset(), filters=data_export.filters, filter_backends=self.get_filter_backends() + ) + self.total_records = self.queryset.count() + if self.job: + self.job.progress.add_stage_param(self.job.job_type_key, "Number of records exported", 0) + self.job.progress.add_stage_param(self.job.job_type_key, "Total records to export", self.total_records) + self.job.save() + + @abstractmethod + def export(self): + """Perform the export process.""" + raise NotImplementedError() + + @abstractmethod + def get_queryset(self): + raise NotImplementedError() + + def get_serializer_class(self): + return self.serializer_class + + def get_filter_backends(self): + from ami.main.api.views import OccurrenceCollectionFilter + + return [OccurrenceCollectionFilter] + + def update_export_stats(self, file_temp_path=None): + """ + Updates record_count based on queryset and file size after export. + """ + # Set record count from queryset + self.data_export.record_count = self.queryset.count() + + # Check if temp file path is provided and update file size + + if file_temp_path and os.path.exists(file_temp_path): + self.data_export.file_size = os.path.getsize(file_temp_path) + + # Save the updated values + self.data_export.save() + + def update_job_progress(self, records_exported): + """ + Updates job progress and record count. 
+ """ + if self.job: + self.job.progress.update_stage( + self.job.job_type_key, progress=round(records_exported / self.total_records, 2) + ) + self.job.progress.add_or_update_stage_param( + self.job.job_type_key, "Number of records exported", records_exported + ) + self.job.save() diff --git a/ami/exports/format_types.py b/ami/exports/format_types.py new file mode 100644 index 000000000..1f5af70bd --- /dev/null +++ b/ami/exports/format_types.py @@ -0,0 +1,159 @@ +import csv +import json +import logging +import tempfile + +from django.core.serializers.json import DjangoJSONEncoder +from rest_framework import serializers + +from ami.exports.base import BaseExporter +from ami.exports.utils import get_data_in_batches +from ami.main.models import Occurrence + +logger = logging.getLogger(__name__) + + +def get_export_serializer(): + from ami.main.api.serializers import OccurrenceSerializer + + class OccurrenceExportSerializer(OccurrenceSerializer): + detection_images = serializers.SerializerMethodField() + + def get_detection_images(self, obj: Occurrence): + """Convert the generator field to a list before serialization""" + if hasattr(obj, "detection_images") and callable(obj.detection_images): + return list(obj.detection_images()) # Convert generator to list + return [] + + def get_permissions(self, instance_data): + return instance_data + + def to_representation(self, instance): + return serializers.HyperlinkedModelSerializer.to_representation(self, instance) + + return OccurrenceExportSerializer + + +class JSONExporter(BaseExporter): + """Handles JSON export of occurrences.""" + + file_format = "json" + + def get_serializer_class(self): + return get_export_serializer() + + def get_queryset(self): + return ( + Occurrence.objects.filter(project=self.project) + .select_related( + "determination", + "deployment", + "event", + ) + .with_timestamps() # type: ignore[union-attr] Custom queryset method + .with_detections_count() + .with_identifications() + ) + + def export(self): + """Exports occurrences to JSON format.""" + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w", encoding="utf-8") + with open(temp_file.name, "w", encoding="utf-8") as f: + first = True + f.write("[") + records_exported = 0 + for i, batch in enumerate(get_data_in_batches(self.queryset, self.get_serializer_class())): + json_data = json.dumps(batch, cls=DjangoJSONEncoder) + json_data = json_data[1:-1] # remove [ and ] from json string + f.write(",\n" if not first else "") + f.write(json_data) + first = False + records_exported += len(batch) + self.update_job_progress(records_exported) + f.write("]") + + self.update_export_stats(file_temp_path=temp_file.name) + return temp_file.name # Return file path + + +class OccurrenceTabularSerializer(serializers.ModelSerializer): + """Serializer to format occurrences for tabular data export.""" + + event_id = serializers.IntegerField(source="event.id", allow_null=True) + event_name = serializers.CharField(source="event.name", allow_null=True) + deployment_id = serializers.IntegerField(source="deployment.id", allow_null=True) + deployment_name = serializers.CharField(source="deployment.name", allow_null=True) + project_id = serializers.IntegerField(source="project.id", allow_null=True) + project_name = serializers.CharField(source="project.name", allow_null=True) + + determination_id = serializers.IntegerField(source="determination.id", allow_null=True) + determination_name = serializers.CharField(source="determination.name", allow_null=True) + 
+    determination_score = serializers.FloatField(allow_null=True)
+    verification_status = serializers.SerializerMethodField()
+
+    class Meta:
+        model = Occurrence
+        fields = [
+            "id",
+            "event_id",
+            "event_name",
+            "deployment_id",
+            "deployment_name",
+            "project_id",
+            "project_name",
+            "determination_id",
+            "determination_name",
+            "determination_score",
+            "verification_status",
+            "detections_count",
+            "first_appearance_timestamp",
+            "last_appearance_timestamp",
+            "duration",
+        ]
+
+    def get_verification_status(self, obj):
+        """
+        Returns 'Verified' if the occurrence has identifications, otherwise 'Not verified'.
+        """
+        return "Verified" if obj.identifications.exists() else "Not verified"
+
+
+class CSVExporter(BaseExporter):
+    """Handles CSV export of occurrences."""
+
+    file_format = "csv"
+
+    serializer_class = OccurrenceTabularSerializer
+
+    def get_queryset(self):
+        return (
+            Occurrence.objects.filter(project=self.project)
+            .select_related(
+                "determination",
+                "deployment",
+                "event",
+            )
+            .with_timestamps()  # type: ignore[union-attr] Custom queryset method
+            .with_detections_count()
+            .with_identifications()
+        )
+
+    def export(self):
+        """Exports occurrences to CSV format."""
+
+        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w", newline="", encoding="utf-8")
+
+        # Extract field names dynamically from the serializer
+        serializer = self.serializer_class()
+        field_names = list(serializer.fields.keys())
+        records_exported = 0
+        with temp_file as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=field_names)
+            writer.writeheader()
+
+            for batch in get_data_in_batches(self.queryset, self.serializer_class):
+                writer.writerows(batch)
+                records_exported += len(batch)
+                self.update_job_progress(records_exported)
+        self.update_export_stats(file_temp_path=temp_file.name)
+        return temp_file.name  # Return the file path
diff --git a/ami/exports/migrations/0001_initial.py b/ami/exports/migrations/0001_initial.py
new file mode 100644
index 000000000..90be52897
--- /dev/null
+++ b/ami/exports/migrations/0001_initial.py
@@ -0,0 +1,57 @@
+# Generated by Django 4.2.10 on 2025-04-02 20:12
+
+from django.conf import settings
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+    initial = True
+
+    dependencies = [
+        ("main", "0058_alter_project_options"),
+        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="DataExport",
+            fields=[
+                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+                ("created_at", models.DateTimeField(auto_now_add=True)),
+                ("updated_at", models.DateTimeField(auto_now=True)),
+                (
+                    "format",
+                    models.CharField(
+                        choices=[
+                            ("occurrences_simple_json", "occurrences_simple_json"),
+                            ("occurrences_simple_csv", "occurrences_simple_csv"),
+                        ],
+                        max_length=255,
+                    ),
+                ),
+                ("filters", models.JSONField(blank=True, null=True)),
+                ("filters_display", models.JSONField(blank=True, null=True)),
+                ("file_url", models.URLField(blank=True, null=True)),
+                ("record_count", models.PositiveIntegerField(default=0)),
+                ("file_size", models.PositiveBigIntegerField(default=0)),
+                (
+                    "project",
+                    models.ForeignKey(
+                        on_delete=django.db.models.deletion.CASCADE, related_name="exports", to="main.project"
+                    ),
+                ),
+                (
+                    "user",
+                    models.ForeignKey(
+                        on_delete=django.db.models.deletion.CASCADE,
+                        related_name="exports",
+                        to=settings.AUTH_USER_MODEL,
+                    ),
+                ),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+    ]
diff --git a/ami/exports/migrations/__init__.py b/ami/exports/migrations/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/ami/exports/models.py b/ami/exports/models.py
new file mode 100644
index 000000000..f0bcff353
--- /dev/null
+++ b/ami/exports/models.py
@@ -0,0 +1,139 @@
+import logging
+
+from django.conf import settings
+from django.core.files.storage import default_storage
+from django.db import models
+from django.utils.text import slugify
+from rest_framework.request import Request
+
+from ami.base.models import BaseModel
+from ami.main.models import Project
+from ami.users.models import User
+
+logger = logging.getLogger(__name__)
+
+
+def get_export_choices():
+    """Dynamically fetch available export formats from the ExportRegistry."""
+    from ami.exports.registry import ExportRegistry
+
+    return [(key, key) for key in ExportRegistry.get_supported_formats()]
+
+
+class DataExport(BaseModel):
+    """A model to track data exports"""
+
+    user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="exports")
+    project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name="exports")
+    format = models.CharField(max_length=255, choices=get_export_choices())
+    filters = models.JSONField(null=True, blank=True)
+    filters_display = models.JSONField(null=True, blank=True)
+    file_url = models.URLField(blank=True, null=True)
+    # Number of exported records.
+    record_count = models.PositiveIntegerField(default=0)
+    # Size of the exported file in bytes.
+    file_size = models.PositiveBigIntegerField(default=0)
+
+    def get_filters_display(self):
+        """
+        Precompute a display-friendly version of filters.
+        """
+        from django.apps import apps
+
+        related_models = {
+            "collection": "main.SourceImageCollection",
+            "taxa_list": "main.TaxaList",
+        }
+        filters = self.filters or {}
+        filters_display = {}
+
+        for key, value in filters.items():
+            if key in related_models:
+                model_path = related_models[key]
+                try:
+                    Model = apps.get_model(model_path)
+                    instance = Model.objects.get(pk=value)
+                    filters_display[key] = {"id": value, "name": str(instance)}
+                except Model.DoesNotExist:
+                    filters_display[key] = {"id": value, "name": f"{model_path} with id {value} not found"}
+                except Exception as e:
+                    filters_display[key] = {"id": value, "name": f"Error: {str(e)}"}
+            else:
+                filters_display[key] = value
+
+        return filters_display
+
+    def generate_filename(self):
+        """Generates a slugified filename using the project name and export ID."""
+        from ami.exports.registry import ExportRegistry
+
+        extension = ExportRegistry.get_exporter(self.format).file_format
+        project_slug = slugify(self.project.name)  # Convert project name to a slug
+        return f"{project_slug}_export-{self.pk}.{extension}"
+
+    def save_export_file(self, file_temp_path):
+        """
+        Saves the exported file to the default storage.
+        """
+        # Generate file path in the 'exports' directory
+        file_path = f"exports/{self.generate_filename()}"
+
+        # Save the file; default_storage.save() may return a different name if file_path already exists
+        with open(file_temp_path, "rb") as f:
+            saved_path = default_storage.save(file_path, f)
+        file_url = f"{settings.MEDIA_URL}{saved_path}"
+        return file_url
+
+    def get_exporter(self):
+        """
+        Initialize and return an Exporter instance based on the requested format.
+
+        The init method of the Exporter class is called here,
+        which can trigger a large query, so do this only once.
+ """ + cache_key = "_exporter" + if hasattr(self, cache_key): + return getattr(self, cache_key) + + from ami.exports.registry import ExportRegistry + + export_format = self.format + ExportClass = ExportRegistry.get_exporter(export_format) + if not ExportClass: + raise ValueError("Invalid export format") + logger.debug(f"Exporter class {ExportClass}") + exporter = ExportClass(self) + setattr(self, cache_key, exporter) + return exporter + + def update_record_count(self): + """ + Calculate and save the total number of records in the export's queryset. + """ + exporter = self.get_exporter() + self.record_count = exporter.total_records + self.save(update_fields=["record_count"]) + return self.record_count + + def run_export(self): + logger.info(f"Starting export for format: {self.format}") + exporter = self.get_exporter() + file_temp_path = exporter.export() + file_url = self.save_export_file(file_temp_path) + self.file_url = file_url + self.save(update_fields=["file_url"]) + return file_url + + def get_absolute_url(self, request: Request | None) -> str | None: + """Returns the full URL of the file.""" + if not self.file_url: + return None + if not request: + return self.file_url + else: + return request.build_absolute_uri(self.file_url) + + def save(self, *args, **kwargs): + # Update filters_display before saving + self.filters_display = self.get_filters_display() + super().save(*args, **kwargs) diff --git a/ami/exports/registry.py b/ami/exports/registry.py new file mode 100644 index 000000000..7af239fd1 --- /dev/null +++ b/ami/exports/registry.py @@ -0,0 +1,29 @@ +import ami.exports.format_types as format_types + + +class ExportRegistry: + _registry = {} + + @classmethod + def register(cls, format_type): + """Decorator to register an export format.""" + + def decorator(exporter_class): + cls._registry[format_type] = exporter_class + return exporter_class + + return decorator + + @classmethod + def get_exporter(cls, format_type): + """Retrieve an exporter class based on format type.""" + return cls._registry.get(format_type, None) + + @classmethod + def get_supported_formats(cls): + """Return a list of registered formats.""" + return list(cls._registry.keys()) + + +ExportRegistry.register("occurrences_simple_json")(format_types.JSONExporter) +ExportRegistry.register("occurrences_simple_csv")(format_types.CSVExporter) diff --git a/ami/exports/serializers.py b/ami/exports/serializers.py new file mode 100644 index 000000000..e16f63025 --- /dev/null +++ b/ami/exports/serializers.py @@ -0,0 +1,74 @@ +from django.template.defaultfilters import filesizeformat +from rest_framework import serializers + +from ami.base.serializers import DefaultSerializer +from ami.exports.registry import ExportRegistry +from ami.jobs.models import Job +from ami.jobs.serializers import JobListSerializer +from ami.main.api.serializers import UserNestedSerializer +from ami.main.models import Project + +from .models import DataExport + + +class DataExportJobNestedSerializer(JobListSerializer): + """ + Job Nested serializer for DataExport. 
+ """ + + class Meta: + model = Job + fields = [ + "id", + "name", + "project", + "progress", + "result", + ] + + +class DataExportSerializer(DefaultSerializer): + """ + Serializer for DataExport + """ + + job = DataExportJobNestedSerializer(read_only=True) # Nested job serializer + user = UserNestedSerializer(read_only=True) + project = serializers.PrimaryKeyRelatedField(queryset=Project.objects.all(), write_only=True) + file_url = serializers.SerializerMethodField() + file_size_display = serializers.SerializerMethodField() + + class Meta: + model = DataExport + fields = [ + "id", + "user", + "project", + "format", + "filters", + "filters_display", + "job", + "file_url", + "record_count", + "file_size", + "file_size_display", + "created_at", + "updated_at", + ] + + def validate_format(self, value): + supported_formats = ExportRegistry.get_supported_formats() + if value not in supported_formats: + raise serializers.ValidationError(f"Invalid format. Supported formats are: {supported_formats}") + return value + + def get_file_url(self, obj): + return obj.get_absolute_url(request=self.context.get("request")) + + def get_file_size_display(self, obj): + """ + Converts file size from bytes to a more readable format. + """ + if not obj.file_size: + return None + return filesizeformat(obj.file_size) diff --git a/ami/exports/signals.py b/ami/exports/signals.py new file mode 100644 index 000000000..7aca8d47e --- /dev/null +++ b/ami/exports/signals.py @@ -0,0 +1,29 @@ +import logging + +from django.conf import settings +from django.core.files.storage import default_storage +from django.db.models.signals import pre_delete +from django.dispatch import receiver + +from .models import DataExport + +logger = logging.getLogger(__name__) + + +@receiver(pre_delete, sender=DataExport) +def delete_exported_file(sender, instance, **kwargs): + """ + Deletes the exported file when the DataExport instance is deleted. 
+ """ + + file_url = instance.file_url + + if file_url: + try: + relative_path = file_url.replace(settings.MEDIA_URL, "").lstrip("/") + if default_storage.exists(relative_path): + default_storage.delete(relative_path) + logger.info(f"Deleted export file: {relative_path}") + + except Exception as e: + logger.error(f"Error deleting export file {relative_path}: {e}") diff --git a/ami/exports/tests.py b/ami/exports/tests.py new file mode 100644 index 000000000..1a7a5532a --- /dev/null +++ b/ami/exports/tests.py @@ -0,0 +1,120 @@ +import csv +import json + +from django.core.files.base import ContentFile +from django.core.files.storage import default_storage +from django.test import TestCase +from rest_framework.test import APIClient + +from ami.exports.models import DataExport +from ami.main.models import Deployment, Occurrence, Project, SourceImageCollection +from ami.tests.fixtures.main import create_captures +from ami.users.models import User + + +class DataExportTest(TestCase): + def setUp(self): + self.user = User.objects.create_user(email="testuser@insectai.org", is_superuser=True, is_staff=True) + self.project = Project.objects.create(name="Test Project") + self.client = APIClient() + self.client.force_authenticate(user=self.user) + # Create test deployment + self.deployment = Deployment.objects.create(name="Test Deployment", project=self.project) + # Create captures for the deployment + create_captures(deployment=self.deployment, num_nights=2, images_per_night=10, interval_minutes=1) + # Create a collection using the provided method + self.collection = self._create_collection() + # Define export formats + self.export_formats = ["occurrences_simple_csv", "occurrences_simple_json"] + + def _create_export_with_file(self, format_type): + filename = f"exports/test_export_file_{format_type}.json" + default_storage.save(filename, ContentFile(b"Dummy content")) + + export = DataExport.objects.create( + user=self.user, + project=self.project, + format=format_type, + file_url=filename, + ) + + return export, filename + + def test_file_is_deleted_when_export_is_deleted(self): + for format_type in self.export_formats: + with self.subTest(format=format_type): + export, filename = self._create_export_with_file(format_type) + + self.assertTrue(default_storage.exists(filename)) + + response = self.client.delete(f"/api/v2/exports/{export.pk}/") + self.assertEqual(response.status_code, 204) + + self.assertFalse(default_storage.exists(filename)) + + def _create_collection(self): + """Create a SourceImageCollection from deployment captures.""" + images = self.deployment.captures.all() + + # Create the collection + collection = SourceImageCollection.objects.create( + name="Test Manual Source Image Collection", + project=self.project, + method="manual", + kwargs={"image_ids": [image.pk for image in images]}, + ) + collection.save() + + # Populate the collection sample + collection.populate_sample() + return collection + + def run_and_validate_export(self, format_type): + """Run export and validate record count in the exported file.""" + # Create a DataExport instance + data_export = DataExport.objects.create( + user=self.user, + project=self.project, + format=format_type, + filters={"collection": self.collection.pk}, + job=None, + ) + + # Run export and get the file URL + file_url = data_export.run_export() + + # Ensure the file is generated + self.assertIsNotNone(file_url) + file_path = file_url.replace("/media/", "") + self.assertTrue(default_storage.exists(file_path)) + + # Read and validate the exported 
+        with default_storage.open(file_path, "r") as f:
+            if format_type == "occurrences_simple_csv":
+                self.validate_csv_records(f)
+            elif format_type == "occurrences_simple_json":
+                self.validate_json_records(f)
+
+        # Clean up the exported file after the test
+        default_storage.delete(file_path)
+
+    def validate_csv_records(self, file):
+        """Validate record count in CSV."""
+        csv_reader = csv.DictReader(file)
+        row_count = sum(1 for _ in csv_reader)
+        expected_count = Occurrence.objects.filter(detections__source_image__collections=self.collection).count()
+        self.assertEqual(row_count, expected_count)
+
+    def validate_json_records(self, file):
+        """Validate record count in JSON."""
+        data = json.load(file)
+        expected_count = Occurrence.objects.filter(detections__source_image__collections=self.collection).count()
+        self.assertEqual(len(data), expected_count)
+
+    def test_csv_export_record_count(self):
+        """Test CSV export record count."""
+        self.run_and_validate_export("occurrences_simple_csv")
+
+    def test_json_export_record_count(self):
+        """Test JSON export record count."""
+        self.run_and_validate_export("occurrences_simple_json")
diff --git a/ami/exports/utils.py b/ami/exports/utils.py
new file mode 100644
index 000000000..e59454219
--- /dev/null
+++ b/ami/exports/utils.py
@@ -0,0 +1,105 @@
+import logging
+
+from django.conf import settings
+from django.db import models
+from django.test import RequestFactory
+from rest_framework import serializers
+from rest_framework.request import Request
+from rest_framework.versioning import NamespaceVersioning
+
+logger = logging.getLogger(__name__)
+
+
+def generate_fake_request(
+    path: str = "/api/v2/occurrences/",
+    method: str = "GET",
+    query_params: dict | None = None,
+    headers: dict | None = None,
+) -> Request:
+    """
+    Generate a fake DRF request object to mimic an actual API request.
+
+    Args:
+        path (str): The API endpoint path (default: occurrences list view).
+        method (str): The HTTP method (default: GET).
+        query_params (dict, optional): Query parameters to include in the request.
+        headers (dict, optional): Additional HTTP headers.
+
+    Returns:
+        Request: A DRF request object that mimics a real API request.
+    """
+
+    from urllib.parse import urlencode
+
+    factory = RequestFactory()
+
+    # Construct the full URL with query parameters
+    full_path = f"{path}?{urlencode(query_params)}" if query_params else path
+
+    # Create the base request
+    request_method = getattr(factory, method.lower(), factory.get)
+    raw_request = request_method(full_path)
+
+    # Set HTTP Host
+    raw_request.META["HTTP_HOST"] = getattr(settings, "EXTERNAL_HOSTNAME", "localhost")
+
+    # Add additional headers if provided
+    if headers:
+        for key, value in headers.items():
+            raw_request.META[f"HTTP_{key.upper().replace('-', '_')}"] = value
+
+    # Wrap in DRF's Request object
+    fake_request = Request(raw_request)
+
+    # Set versioning details
+    fake_request.version = "api"
+    fake_request.versioning_scheme = NamespaceVersioning()
+
+    return fake_request
+
+
+def apply_filters(queryset, filters, filter_backends):
+    """
+    Apply filtering backends to the given queryset using the provided filter query params.
+ """ + request = generate_fake_request(query_params=filters) + logger.debug(f"Queryset count before filtering : {queryset.count()}") + logger.debug(f"Filter values : {filters}") + logger.debug(f"Filter backends : {filter_backends}") + for backend in filter_backends: + queryset = backend().filter_queryset(request, queryset, None) # `view` is None since we are not using ViewSet + logger.debug(f"Queryset count after filtering : {queryset.count()}") + return queryset + + +def get_data_in_batches(QuerySet: models.QuerySet, Serializer: type[serializers.Serializer], batch_size=100): + """ + Yield batches of serialized data from a queryset efficiently. + """ + items = QuerySet.iterator(chunk_size=batch_size) # Efficient iteration to avoid memory issues + batch = [] + + fake_request = generate_fake_request() + for i, item in enumerate(items): + try: + serializer = Serializer( + item, + context={ + "request": fake_request, + }, + ) + + item_data = serializer.data + + batch.append(item_data) + + # Yield batch once it reaches batch_size + if len(batch) >= batch_size: + yield batch + batch = [] # Reset batch + except Exception as e: + logger.warning(f"Error processing occurrence {item.id}: {str(e)}") + raise e + + if len(batch): + yield batch # yield the last batch if total number of records not divisible by batch_size diff --git a/ami/exports/views.py b/ami/exports/views.py new file mode 100644 index 000000000..0c8cc9a16 --- /dev/null +++ b/ami/exports/views.py @@ -0,0 +1,78 @@ +from rest_framework import status +from rest_framework.response import Response + +from ami.base.views import ProjectMixin +from ami.exports.serializers import DataExportSerializer +from ami.jobs.models import DataExportJob, Job, SourceImageCollection +from ami.main.api.views import DefaultViewSet + +from .models import DataExport + + +class ExportViewSet(DefaultViewSet, ProjectMixin): + """ + API endpoint for exporting occurrences. + """ + + queryset = DataExport.objects.all() + serializer_class = DataExportSerializer + ordering_fields = ["id", "format", "file_size", "created_at", "updated_at"] + + def get_queryset(self): + queryset = super().get_queryset().select_related("job") + project = self.get_active_project() + if project: + queryset = queryset.filter(project=project) + return queryset + + def create(self, request, *args, **kwargs): + """ + Create a new DataExport entry and trigger the export job. 
+ """ + + # Use serializer for validation + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + validated_data = serializer.validated_data + format_type = validated_data["format"] + filters = validated_data.get("filters", {}) + project = validated_data["project"] + + # Check optional collection filter + collection = None + collection_id = filters.get("collection_id") + if collection_id: + try: + collection = SourceImageCollection.objects.get(pk=collection_id) + except SourceImageCollection.DoesNotExist: + return Response( + {"error": "Collection does not exist."}, + status=status.HTTP_400_BAD_REQUEST, + ) + if collection.project != project: + return Response( + {"error": "Collection does not belong to the selected project."}, + status=status.HTTP_400_BAD_REQUEST, + ) + + # Create DataExport object + data_export = DataExport.objects.create( + user=request.user, + format=format_type, + filters=filters, + project=project, + ) + data_export.update_record_count() + + job_name = f"Export occurrences{f' for collection {collection.pk}' if collection else ''}" + job = Job.objects.create( + name=job_name, + project=project, + job_type_key=DataExportJob.key, + data_export=data_export, + source_image_collection=collection, + ) + job.enqueue() + + return Response(self.get_serializer(data_export).data, status=status.HTTP_201_CREATED) diff --git a/ami/jobs/admin.py b/ami/jobs/admin.py index 10cbdba89..b5c921502 100644 --- a/ami/jobs/admin.py +++ b/ami/jobs/admin.py @@ -1,6 +1,6 @@ from django.contrib import admin from django.db.models.query import QuerySet -from django.http.request import HttpRequest +from django.http import HttpRequest from ami.main.admin import AdminBase diff --git a/ami/jobs/migrations/0016_job_data_export_job_params_alter_job_job_type_key.py b/ami/jobs/migrations/0016_job_data_export_job_params_alter_job_job_type_key.py new file mode 100644 index 000000000..17eac7fe1 --- /dev/null +++ b/ami/jobs/migrations/0016_job_data_export_job_params_alter_job_job_type_key.py @@ -0,0 +1,46 @@ +# Generated by Django 4.2.10 on 2025-04-02 20:12 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("exports", "0001_initial"), + ("jobs", "0015_merge_20250117_2100"), + ] + + operations = [ + migrations.AddField( + model_name="job", + name="data_export", + field=models.OneToOneField( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="job", + to="exports.dataexport", + ), + ), + migrations.AddField( + model_name="job", + name="params", + field=models.JSONField(blank=True, null=True), + ), + migrations.AlterField( + model_name="job", + name="job_type_key", + field=models.CharField( + choices=[ + ("ml", "ML pipeline"), + ("populate_captures_collection", "Populate captures collection"), + ("data_storage_sync", "Data storage sync"), + ("unknown", "Unknown"), + ("data_export", "Data Export"), + ], + default="unknown", + max_length=255, + verbose_name="Job Type", + ), + ), + ] diff --git a/ami/jobs/models.py b/ami/jobs/models.py index c5853689d..1166f2c74 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -603,6 +603,42 @@ def run(cls, job: "Job"): job.save() +class DataExportJob(JobType): + """ + Job type to handle Project data exports + """ + + name = "Data Export" + key = "data_export" + + @classmethod + def run(cls, job: "Job"): + """ + Run the export job asynchronously with format selection (CSV, JSON, Darwin Core). 
+ """ + logger.info("Job started: Exporting occurrences") + + # Add progress tracking + job.progress.add_stage("Exporting data", cls.key) + job.update_status(JobState.STARTED) + job.started_at = datetime.datetime.now() + job.finished_at = None + job.save() + + job.logger.info(f"Starting export for project {job.project}") + + file_url = job.data_export.run_export() + + job.logger.info(f"Export completed: {file_url}") + job.logger.info(f"File uploaded to Project Storage: {file_url}") + # Finalize Job + stage = job.progress.add_stage("Uploading snapshot") + job.progress.add_stage_param(stage.key, "File URL", f"{file_url}") + job.progress.update_stage(stage.key, status=JobState.SUCCESS, progress=1) + job.finished_at = datetime.datetime.now() + job.update_status(JobState.SUCCESS, save=True) + + class UnknownJobType(JobType): name = "Unknown" key = "unknown" @@ -612,7 +648,7 @@ def run(cls, job: "Job"): raise ValueError(f"Unknown job type '{job.job_type()}'") -VALID_JOB_TYPES = [MLJob, SourceImageCollectionPopulateJob, DataStorageSyncJob, UnknownJobType] +VALID_JOB_TYPES = [MLJob, SourceImageCollectionPopulateJob, DataStorageSyncJob, UnknownJobType, DataExportJob] def get_job_type_by_key(key: str) -> type[JobType] | None: @@ -653,6 +689,7 @@ class Job(BaseModel): status = models.CharField(max_length=255, default=JobState.CREATED.name, choices=JobState.choices()) progress: JobProgress = SchemaField(JobProgress, default=default_job_progress()) logs: JobLogs = SchemaField(JobLogs, default=JobLogs()) + params = models.JSONField(null=True, blank=True) result = models.JSONField(null=True, blank=True) task_id = models.CharField(max_length=255, null=True, blank=True) delay = models.IntegerField("Delay in seconds", default=0, help_text="Delay before running the job") @@ -690,6 +727,13 @@ class Job(BaseModel): blank=True, related_name="jobs", ) + data_export = models.OneToOneField( + "exports.DataExport", + on_delete=models.CASCADE, # If DataExport is deleted, delete the Job + null=True, + blank=True, + related_name="job", + ) pipeline = models.ForeignKey( Pipeline, on_delete=models.SET_NULL, diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index f1fef491d..4ac8e4e90 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -1,6 +1,7 @@ from django_pydantic_field.rest_framework import SchemaField from rest_framework import serializers +from ami.exports.models import DataExport from ami.main.api.serializers import ( DefaultSerializer, DeploymentNestedSerializer, @@ -24,6 +25,14 @@ class Meta: ] +class DataExportNestedSerializer(serializers.ModelSerializer): + file_url = serializers.URLField(read_only=True) + + class Meta: + model = DataExport + fields = ["id", "user", "project", "format", "filters", "file_url"] + + class JobTypeSerializer(serializers.Serializer): name = serializers.CharField(read_only=True) key = serializers.SlugField(read_only=True) @@ -36,6 +45,7 @@ class JobListSerializer(DefaultSerializer): pipeline = PipelineNestedSerializer(read_only=True) source_image_collection = SourceImageCollectionNestedSerializer(read_only=True) source_image_single = SourceImageNestedSerializer(read_only=True) + data_export = DataExportNestedSerializer(read_only=True) progress = SchemaField(schema=JobProgress, read_only=True) logs = SchemaField(schema=JobLogs, read_only=True) job_type = JobTypeSerializer(read_only=True) @@ -116,6 +126,7 @@ class Meta: "logs", "job_type", "job_type_key", + "data_export", # "duration", # "duration_label", # "progress_label", diff --git 
--- a/ami/main/api/serializers.py
+++ b/ami/main/api/serializers.py
@@ -1141,7 +1141,7 @@ def get_permissions(self, instance, instance_data):
         permissions = set()
         if instance.user == user or ProjectManager.has_role(user, project):
             permissions.add("delete")
-        instance_data["user_permissions"] = permissions
+        instance_data["user_permissions"] = list(permissions)
         return instance_data
 
     class Meta:
@@ -1175,7 +1175,7 @@ def get_permissions(self, instance, instance_data):
             # then add update permission to response
             permissions.add("update")
 
-        instance_data["user_permissions"] = permissions
+        instance_data["user_permissions"] = list(permissions)
         return instance_data
 
     class Meta:
diff --git a/ami/main/api/views.py b/ami/main/api/views.py
index fc82f5236..2c21d757c 100644
--- a/ami/main/api/views.py
+++ b/ami/main/api/views.py
@@ -126,7 +126,7 @@ def create(self, request, *args, **kwargs):
         serializer.is_valid(raise_exception=True)
 
         # Create instance but do not save
-        instance = serializer.Meta.model(**serializer.validated_data)
+        instance = serializer.Meta.model(**serializer.validated_data)  # type: ignore
         self.check_object_permissions(request, instance)
         self.perform_create(serializer)
         return Response(serializer.data, status=status.HTTP_201_CREATED)
@@ -657,7 +657,7 @@ def get_queryset(self) -> QuerySet:
         project = self.get_active_project()
         if project:
             query_set = query_set.filter(project=project)
-        queryset = query_set.with_occurrences_count(
+        queryset = query_set.with_occurrences_count(  # type: ignore
             classification_threshold=classification_threshold
         ).with_taxa_count(  # type: ignore
             classification_threshold=classification_threshold
@@ -1531,7 +1531,7 @@ def perform_create(self, serializer):
         Set the user to the current user.
""" # Get an instance for the model without saving - obj = serializer.Meta.model(**serializer.validated_data, user=self.request.user) + obj = serializer.Meta.model(**serializer.validated_data, user=self.request.user) # type: ignore # Check permissions before saving self.check_object_permissions(self.request, obj) diff --git a/config/api_router.py b/config/api_router.py index 87b7647ee..4d4d593f3 100644 --- a/config/api_router.py +++ b/config/api_router.py @@ -4,6 +4,7 @@ from djoser.views import UserViewSet from rest_framework.routers import DefaultRouter, SimpleRouter +from ami.exports import views as export_views from ami.jobs import views as job_views from ami.labelstudio import views as labelstudio_views from ami.main.api import views @@ -21,6 +22,7 @@ router.register(r"deployments/sites", views.SiteViewSet) router.register(r"deployments", views.DeploymentViewSet) router.register(r"events", views.EventViewSet) +router.register(r"exports", export_views.ExportViewSet) router.register(r"captures/collections", views.SourceImageCollectionViewSet) router.register(r"captures/upload", views.SourceImageUploadViewSet) router.register(r"captures", views.SourceImageViewSet) @@ -36,6 +38,7 @@ router.register(r"identifications", views.IdentificationViewSet) router.register(r"jobs", job_views.JobViewSet) router.register(r"pages", views.PageViewSet) +router.register(r"exports", export_views.ExportViewSet) router.register( r"labelstudio/captures", labelstudio_views.LabelStudioSourceImageViewSet, basename="labelstudio-captures" ) diff --git a/config/settings/base.py b/config/settings/base.py index d9e3ce82b..79a063f81 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -101,6 +101,7 @@ "ami.jobs", "ami.ml", "ami.labelstudio", + "ami.exports", ] # https://docs.djangoproject.com/en/dev/ref/settings/#installed-apps INSTALLED_APPS = DJANGO_APPS + THIRD_PARTY_APPS + LOCAL_APPS diff --git a/ui/src/data-services/models/job.ts b/ui/src/data-services/models/job.ts index 738e6a323..e7758eaa2 100644 --- a/ui/src/data-services/models/job.ts +++ b/ui/src/data-services/models/job.ts @@ -19,6 +19,7 @@ export const SERVER_JOB_TYPES = [ 'ml', 'data_storage_sync', 'populate_captures_collection', + 'data_export', 'unknown', ] as const @@ -175,6 +176,7 @@ export class Job { ml: 'ML pipeline', data_storage_sync: 'Data storage sync', populate_captures_collection: 'Populate captures collection', + data_export: 'Data export', unknown: 'Unknown', }[key]