diff --git a/ami/base/views.py b/ami/base/views.py index 3291a9b10..eac56d513 100644 --- a/ami/base/views.py +++ b/ami/base/views.py @@ -22,6 +22,8 @@ class ProjectMixin: def get_active_project(self) -> Project: from ami.base.serializers import SingleParamSerializer + param = "project_id" + project_id = None # Extract from URL `/projects/` is in the url path if "/projects/" in self.request.path: @@ -29,13 +31,14 @@ def get_active_project(self) -> Project: # If not in URL, try query parameters if not project_id: - if self.require_project: - project_id = SingleParamSerializer[int].clean( - param_name="project_id", - field=serializers.IntegerField(required=True, min_value=0), - data=self.request.query_params, - ) - else: - project_id = self.request.query_params.get("project_id") # No validation + # Look for project_id in GET query parameters or POST data + # POST data returns a list of ints, but QueryDict.get() returns a single value + project_id = self.request.query_params.get(param) or self.request.data.get(param) + + project_id = SingleParamSerializer[int].clean( + param_name=param, + field=serializers.IntegerField(required=self.require_project, min_value=0), + data={param: project_id} if project_id else {}, + ) return get_object_or_404(Project, id=project_id) if project_id else None diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py index c9e7d4145..61797f9f8 100644 --- a/ami/main/api/serializers.py +++ b/ami/main/api/serializers.py @@ -1,15 +1,14 @@ import datetime -from django.core.exceptions import ValidationError as DjangoValidationError from django.db.models import QuerySet from guardian.shortcuts import get_perms from rest_framework import serializers from rest_framework.request import Request from ami.base.fields import DateStringField -from ami.base.serializers import DefaultSerializer, MinimalNestedModelSerializer, get_current_user, reverse_with_params +from ami.base.serializers import DefaultSerializer, MinimalNestedModelSerializer, reverse_with_params from ami.jobs.models import Job -from ami.main.models import Tag, create_source_image_from_upload +from ami.main.models import Tag from ami.ml.models import Algorithm, Pipeline from ami.ml.serializers import AlgorithmSerializer, PipelineNestedSerializer from ami.users.models import User @@ -33,7 +32,6 @@ SourceImageUpload, TaxaList, Taxon, - validate_filename_timestamp, ) @@ -1085,30 +1083,6 @@ class Meta: "created_at", ] - def create(self, validated_data): - # Add the user to the validated data - request = self.context.get("request") - user = get_current_user(request) - # @TODO IMPORTANT ensure current user is a member of the deployment's project - obj = SourceImageUpload.objects.create(user=user, **validated_data) - source_image = create_source_image_from_upload( - obj.image, - obj.deployment, - request, - ) - if source_image is not None: - obj.source_image = source_image # type: ignore - obj.save() - return obj - - def validate_image(self, value): - # Ensure that image filename contains a timestamp - try: - validate_filename_timestamp(value.name) - except DjangoValidationError as e: - raise serializers.ValidationError(str(e)) - return value - class SourceImageCollectionCommonKwargsSerializer(serializers.Serializer): # The most common kwargs for the sampling methods diff --git a/ami/main/api/views.py b/ami/main/api/views.py index 4a6fedb97..3abdb3a47 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -760,6 +760,7 @@ class SourceImageUploadViewSet(DefaultViewSet, ProjectMixin): serializer_class = SourceImageUploadSerializer permission_classes = [SourceImageUploadCRUDPermission] + require_project = True def get_queryset(self) -> QuerySet: # Only allow users to see their own uploads @@ -772,6 +773,35 @@ def get_queryset(self) -> QuerySet: # This is the maximum limit for manually uploaded captures pagination_class.default_limit = 20 + def perform_create(self, serializer): + """ + Save the SourceImageUpload with the current user and create the associated SourceImage. + """ + from ami.base.serializers import get_current_user + from ami.main.models import create_source_image_from_upload + + # Get current user from request + user = get_current_user(self.request) + project = self.get_active_project() + + # Create the SourceImageUpload object with the user + obj = serializer.save(user=user) + + # Get process_now flag from project feature flags + process_now = project.feature_flags.auto_processs_manual_uploads + + # Create source image from the upload + source_image = create_source_image_from_upload( + image=obj.image, + deployment=obj.deployment, + request=self.request, + process_now=process_now, + ) + + # Update the source_image reference and save + obj.source_image = source_image + obj.save() + class DetectionViewSet(DefaultViewSet, ProjectMixin): """ diff --git a/ami/main/migrations/0066_alter_project_feature_flags_and_more.py b/ami/main/migrations/0066_alter_project_feature_flags_and_more.py new file mode 100644 index 000000000..93cd5bc75 --- /dev/null +++ b/ami/main/migrations/0066_alter_project_feature_flags_and_more.py @@ -0,0 +1,29 @@ +# Generated by Django 4.2.10 on 2025-08-08 21:53 + +import ami.main.models +from django.db import migrations, models +import django_pydantic_field.fields + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0065_project_default_filters_exclude_taxa_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="project", + name="feature_flags", + field=django_pydantic_field.fields.PydanticSchemaField( + blank=True, + config=None, + default={"auto_processs_manual_uploads": False, "tags": False}, + schema=ami.main.models.ProjectFeatureFlags, + ), + ), + migrations.AlterField( + model_name="sourceimageupload", + name="image", + field=models.ImageField(upload_to=ami.main.models.upload_to_with_deployment), + ), + ] diff --git a/ami/main/models.py b/ami/main/models.py index 3172499cb..cd75ce571 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -1,14 +1,15 @@ import collections import datetime import functools -import hashlib import logging import textwrap import time import typing import urllib.parse +from io import BytesIO from typing import Final, final # noqa: F401 +import PIL.Image import pydantic from django.apps import apps from django.conf import settings @@ -31,11 +32,12 @@ from ami.main import charts from ami.main.models_future.projects import ProjectSettingsMixin from ami.users.models import User +from ami.utils.media import calculate_file_checksum, extract_timestamp from ami.utils.schemas import OrderedEnum if typing.TYPE_CHECKING: from ami.jobs.models import Job - from ami.ml.models import ProcessingService + from ami.ml.models import Pipeline, ProcessingService logger = logging.getLogger(__name__) @@ -120,12 +122,16 @@ def get_or_create_default_deployment( def get_or_create_default_collection(project: "Project") -> "SourceImageCollection": - """Create a default collection for a project for all images, updated dynamically.""" + """ + Create a default collection for a project for all images. + + @TODO Consider ways to update this collection automatically. With a query-only collection + or a periodic task that runs the populate_collection method. + """ collection, _created = SourceImageCollection.objects.get_or_create( name="All Images", project=project, method="full", - # @TODO make this a dynamic collection that updates automatically ) logger.info(f"Created default collection for project {project}") return collection @@ -196,6 +202,7 @@ class ProjectFeatureFlags(pydantic.BaseModel): """ tags: bool = False # Whether the project supports tagging taxa + auto_processs_manual_uploads: bool = False # Whether to automatically process uploaded images default_feature_flags = ProjectFeatureFlags() @@ -233,6 +240,7 @@ class Project(ProjectSettingsMixin, BaseModel): jobs: models.QuerySet["Job"] sourceimage_collections: models.QuerySet["SourceImageCollection"] processing_services: models.QuerySet["ProcessingService"] + pipelines: models.QuerySet["Pipeline"] tags: models.QuerySet["Tag"] objects = ProjectManager() @@ -1373,11 +1381,37 @@ def validate_filename_timestamp(filename: str) -> None: raise ValidationError("Image filename does not contain a valid timestamp (e.g. YYYYMMDDHHMMSS-snapshot.jpg).") -def create_source_image_from_upload(image: ImageFieldFile, deployment: Deployment, request=None) -> "SourceImage": +def create_source_image_from_upload( + image: ImageFieldFile, + deployment: Deployment, + request=None, + process_now=True, +) -> "SourceImage": """Create a complete SourceImage from an uploaded file.""" - # md5 checksum from file - checksum = hashlib.md5(image.read()).hexdigest() - checksum_algorithm = "md5" + + # Read file content once + image.seek(0) + file_content = image.read() + + # Calculate a checksum for the image content + checksum, checksum_algorithm = calculate_file_checksum(file_content) + + # Create PIL image from file content (no additional file reads) + image_stream = BytesIO(file_content) + pil_image = PIL.Image.open(image_stream) + + timestamp = extract_timestamp(filename=image.name, image=pil_image) + if not timestamp: + logger.warning( + "A valid timestamp could not be found in the image's EXIF data or filename. " + "Please rename the file to include a timestamp " + "(e.g. YYYYMMDDHHMMSS-snapshot.jpg). " + "Falling back to the current time for the image captured timestamp." + ) + timestamp = timezone.now() + width = pil_image.width + height = pil_image.height + size = len(file_content) # get full public media url of image: if request: @@ -1385,23 +1419,26 @@ def create_source_image_from_upload(image: ImageFieldFile, deployment: Deploymen else: base_url = settings.MEDIA_URL - source_image = SourceImage( + source_image = SourceImage.objects.create( path=image.name, # Includes relative path from MEDIA_ROOT public_base_url=base_url, # @TODO how to merge this with the data source? project=deployment.project, deployment=deployment, - timestamp=None, # Will be calculated from filename or EXIF data on save + timestamp=timestamp, event=None, # Will be assigned when the image is grouped into events - size=image.size, + size=size, checksum=checksum, checksum_algorithm=checksum_algorithm, - width=image.width, - height=image.height, + width=width, + height=height, test_image=True, uploaded_by=request.user if request else None, ) - source_image.save() - deployment.save() + deployment.save(regroup_async=False) + if process_now: + from ami.ml.orchestration.processing import process_single_source_image + + process_single_source_image(source_image=source_image) return source_image @@ -1418,7 +1455,7 @@ class SourceImageUpload(BaseModel): The SourceImageViewSet will create a SourceImage from the uploaded file and delete the upload. """ - image = models.ImageField(upload_to=upload_to_with_deployment, validators=[validate_filename_timestamp]) + image = models.ImageField(upload_to=upload_to_with_deployment) user = models.ForeignKey(User, on_delete=models.SET_NULL, null=True, blank=True) deployment = models.ForeignKey(Deployment, on_delete=models.CASCADE, related_name="manually_uploaded_captures") source_image = models.OneToOneField( diff --git a/ami/main/tests.py b/ami/main/tests.py index 06d5d6783..445960ba6 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -1618,7 +1618,11 @@ def _test_sourceimageupload_permissions(self, user, permission_map): # --- Test Create --- response = self.client.post( list_url, - {"image": self._create_source_image_upload_file(), "deployment": self.deployment.id}, + { + "image": self._create_source_image_upload_file(), + "deployment": self.deployment.pk, + "project_id": self.project.pk, + }, format="multipart", ) diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index 5da9e14e5..65ad9e02c 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from ami.ml.models import ProcessingService, ProjectPipelineConfig from ami.jobs.models import Job + from ami.main.models import Project import collections import dataclasses @@ -886,6 +887,40 @@ class PipelineStage(ConfigurableStage): """A configurable stage of a pipeline.""" +class PipelineQuerySet(models.QuerySet): + """Custom QuerySet for Pipeline model.""" + + def enabled(self, project: Project) -> PipelineQuerySet: + """ + Return pipelines that are enabled for a given project. + + # @TODO how can this automatically filter based on the pipeline's projects + # or the current query without having to specify the project? (e.g. with OuterRef?) + """ + return self.filter( + projects=project, + project_pipeline_configs__enabled=True, + project_pipeline_configs__project=project, + processing_services__projects=project, + ).distinct() + + def online(self, project: Project) -> PipelineQuerySet: + """ + Return pipelines that are available at least one online processing service. + """ + return self.filter( + processing_services__projects=project, + processing_services__last_checked_live=True, + ).distinct() + + +class PipelineManager(models.Manager): + """Custom Manager for Pipeline model.""" + + def get_queryset(self) -> PipelineQuerySet: + return PipelineQuerySet(self.model, using=self._db) + + @typing.final class Pipeline(BaseModel): """A pipeline of algorithms""" @@ -917,6 +952,9 @@ class Pipeline(BaseModel): "and the processing service." ), ) + + objects = PipelineManager() + processing_services: models.QuerySet[ProcessingService] project_pipeline_configs: models.QuerySet[ProjectPipelineConfig] jobs: models.QuerySet[Job] diff --git a/ami/ml/orchestration/__init__.py b/ami/ml/orchestration/__init__.py new file mode 100644 index 000000000..d05bbbd82 --- /dev/null +++ b/ami/ml/orchestration/__init__.py @@ -0,0 +1 @@ +from .processing import * # noqa: F401, F403 diff --git a/ami/ml/orchestration/pipelines.py b/ami/ml/orchestration/pipelines.py new file mode 100644 index 000000000..360434e35 --- /dev/null +++ b/ami/ml/orchestration/pipelines.py @@ -0,0 +1,24 @@ +from django.db import models + +from ami.main.models import Project +from ami.ml.models.pipeline import Pipeline + + +def get_default_pipeline(project: Project) -> Pipeline | None: + """ + Select a default pipeline to use for processing images in a project. + + This is a placeholder function that selects the pipeline with the most categories + and which is enabled for the project. + + @TODO use project settings to determine the default pipeline + """ + default_pipeline = project.default_processing_pipeline or ( + Pipeline.objects.all() + .enabled(project=project) # type: ignore + .online(project=project) # type: ignore + .annotate(num_categories=models.Count("algorithms__category_map__labels")) + .order_by("-num_categories", "-created_at") + .first() + ) + return default_pipeline diff --git a/ami/ml/orchestration/processing.py b/ami/ml/orchestration/processing.py new file mode 100644 index 000000000..c09b9aa56 --- /dev/null +++ b/ami/ml/orchestration/processing.py @@ -0,0 +1,45 @@ +import typing + +from ami.jobs.models import Job +from ami.ml.models import Pipeline +from ami.ml.orchestration.pipelines import get_default_pipeline + +if typing.TYPE_CHECKING: + from ami.main.models import SourceImage + + +def process_single_source_image( + source_image: "SourceImage", + pipeline: "Pipeline | None" = None, + run_async=True, +) -> "Job": + """ + Process a single SourceImage immediately. + """ + + assert source_image.deployment is not None, "SourceImage must belong to a deployment" + + if not source_image.event: + source_image.deployment.save(regroup_async=False) + source_image.refresh_from_db() + assert source_image.event is not None, "SourceImage must belong to an event" + + project = source_image.project + assert project is not None, "SourceImage must belong to a project" + + pipeline_choice = pipeline or get_default_pipeline(project) + assert pipeline_choice is not None, "Project must have a pipeline to run" + + # @TODO add images to a queue without creating a job for each image + job = Job.objects.create( + name=f"Capture #{source_image.pk} ({source_image.timestamp}) from {source_image.deployment.name}", + job_type_key="ml", + source_image_single=source_image, + pipeline=pipeline_choice, + project=project, + ) + if run_async: + job.enqueue() + else: + job.run() + return job diff --git a/ami/utils/media.py b/ami/utils/media.py new file mode 100644 index 000000000..449b45c43 --- /dev/null +++ b/ami/utils/media.py @@ -0,0 +1,124 @@ +import hashlib +import logging +from datetime import datetime + +from PIL import Image +from PIL.ExifTags import TAGS + +from ami.utils.dates import get_image_timestamp_from_filename + +logger = logging.getLogger(__name__) + + +def extract_timestamp_from_exif(image: Image.Image) -> datetime | None: + """ + Extract timestamp from EXIF data using existing Pillow image object. + + NOTE: This function explicitly strips timezone information and returns + naive datetime objects representing the local time when the photo was taken. + We only care about the local timestamp, not the timezone. + + Args: + image: PIL Image object + + Returns: + datetime: Naive datetime parsed from EXIF DateTimeOriginal (timezone stripped), + or None if not found + """ + try: + image.seek(0) # Ensure we read from the start of the image file + exif_data = image.getexif() + + if not exif_data: + logger.info("No EXIF data found in image") + return None + + # Convert tag IDs to readable names + exif_dict = {} + for tag_id, value in exif_data.items(): + tag_name = TAGS.get(tag_id, tag_id) + exif_dict[tag_name] = value + + # Try multiple timestamp fields in order of preference + timestamp_fields = [ + "DateTimeOriginal", # When photo was taken (preferred) + "DateTime", # When file was last modified + ] + + for field in timestamp_fields: + if field in exif_dict: + timestamp_str = exif_dict[field] + logger.debug(f"Found EXIF timestamp in {field}: {timestamp_str}") + + try: + # Parse EXIF datetime format: "YYYY:MM:DD HH:MM:SS" + # Note: EXIF timestamps are typically timezone-naive anyway, + # but we explicitly ensure we return a naive datetime + parsed_timestamp = datetime.strptime(timestamp_str, "%Y:%m:%d %H:%M:%S") + + # Explicitly strip timezone if somehow present (should be rare) + naive_timestamp = parsed_timestamp.replace(tzinfo=None) + + logger.info(f"Successfully parsed EXIF timestamp (timezone stripped): {naive_timestamp}") + return naive_timestamp + + except ValueError as e: + logger.warning(f"Failed to parse timestamp '{timestamp_str}' from {field}: {e}") + continue + + logger.info("No valid EXIF timestamp found in image") + return None + + except Exception as e: + logger.error(f"Error extracting EXIF timestamp: {e}") + return None + + +def extract_timestamp(filename: str, image: Image.Image | None = None) -> datetime | None: + """ + Extract timestamp from filename or EXIF data of an image. + + Args: + filename: Name of the file to extract timestamp from + image: Optional PIL Image object to extract EXIF timestamp from + + Returns: + datetime: Naive datetime object representing the timestamp, or None if not found + """ + # First try to get timestamp from filename + timestamp = get_image_timestamp_from_filename(filename) + if timestamp: + logger.info(f"Extracted timestamp from filename: {timestamp}") + return timestamp + + # If no valid timestamp from filename, try EXIF data if image is provided + if image: + exif_timestamp = extract_timestamp_from_exif(image) + if exif_timestamp: + logger.info(f"Extracted timestamp from EXIF data: {exif_timestamp}") + return exif_timestamp + + logger.warning("No valid timestamp found in filename or EXIF data") + return None + + +def calculate_file_checksum(file_content: bytes, algorithm: str = "md5") -> tuple[str, str]: + """ + Calculate checksum for file content. + + Args: + file_content: Raw file bytes + algorithm: Hash algorithm to use ("md5", "sha256", etc.) + + Returns: + tuple: (checksum_hex_string, algorithm_name) + """ + if algorithm.lower() == "md5": + checksum = hashlib.md5(file_content).hexdigest() + elif algorithm.lower() == "sha256": + checksum = hashlib.sha256(file_content).hexdigest() + else: + raise ValueError(f"Unsupported hash algorithm: {algorithm}") + + logger.debug(f"Calculated {algorithm} checksum: {checksum}") + return checksum, algorithm.lower() diff --git a/ui/src/pages/job-details/job-details-form/job-details-form.tsx b/ui/src/pages/job-details/job-details-form/job-details-form.tsx index aa333b254..4d915d926 100644 --- a/ui/src/pages/job-details/job-details-form/job-details-form.tsx +++ b/ui/src/pages/job-details/job-details-form/job-details-form.tsx @@ -43,7 +43,7 @@ const config: FormConfig = { label: translate(STRING.FIELD_LABEL_PIPELINE), }, sourceImages: { - label: translate(STRING.FIELD_LABEL_SOURCE_IMAGES), + label: translate(STRING.FIELD_LABEL_SOURCE_IMAGES_COLLECTION), }, startNow: { label: 'Start immediately', diff --git a/ui/src/pages/job-details/job-details.tsx b/ui/src/pages/job-details/job-details.tsx index d705e125a..71401b7ef 100644 --- a/ui/src/pages/job-details/job-details.tsx +++ b/ui/src/pages/job-details/job-details.tsx @@ -132,7 +132,7 @@ const JobSummary = ({ job }: { job: Job }) => { /> ) : job.sourceImages ? ( TableColumn[] = ( { id: 'source-image-collection', sortField: 'source_image_collection', - name: translate(STRING.FIELD_LABEL_SOURCE_IMAGES), + name: translate(STRING.FIELD_LABEL_SOURCE_IMAGES_COLLECTION), renderCell: (item: Job) => item.sourceImages ? (