diff --git a/ami/main/admin.py b/ami/main/admin.py
index 684ec9b54..b39858621 100644
--- a/ami/main/admin.py
+++ b/ami/main/admin.py
@@ -396,7 +396,7 @@ class OccurrenceAdmin(admin.ModelAdmin[Occurrence]):
         "determination__rank",
         "created_at",
     )
-    search_fields = ("determination__name", "determination__search_names")
+    search_fields = ("determination__name", "determination__search_names", "pk")
 
     def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
         qs = super().get_queryset(request)
@@ -418,10 +418,33 @@ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
     def detections_count(self, obj) -> int:
         return obj.detections_count
 
+    @admin.action(description="Apply class masking to occurrences")
+    def apply_class_mask(self, request: HttpRequest, queryset: QuerySet[Occurrence]) -> QuerySet[Occurrence]:
+        """
+        Apply class masking to the queryset.
+        This is a placeholder for the actual implementation.
+        """
+        from ami.main.models import TaxaList
+        from ami.ml.models import Algorithm
+        from ami.ml.post_processing import class_masking
+
+        taxa_list = TaxaList.objects.get(pk=5)
+        algorithm = Algorithm.objects.get(pk=11)
+
+        for occurrence in queryset:
+            class_masking.update_single_occurrence(
+                occurrence=occurrence,
+                algorithm=algorithm,
+                taxa_list=taxa_list,
+            )
+        return queryset
+
+    actions = [apply_class_mask]
+
     ordering = ("-created_at",)
 
     # Add classifications as inline
-    inlines = [DetectionInline]
+    # inlines = [DetectionInline]
 
 
 @admin.register(Classification)
diff --git a/ami/main/management/commands/import_trapdata_project.py b/ami/main/management/commands/import_trapdata_project.py
index fa43829a1..04440c8a1 100644
--- a/ami/main/management/commands/import_trapdata_project.py
+++ b/ami/main/management/commands/import_trapdata_project.py
@@ -1,140 +1,286 @@
 import datetime
 import json
 
+import pydantic
 from dateutil.parser import parse as parse_date
-from django.core.management.base import BaseCommand, CommandError  # noqa
+from django.core.management.base import BaseCommand
 
-from ...models import Algorithm, Classification, Deployment, Detection, Event, Occurrence, Project, SourceImage, Taxon
+from ami.main.models import Classification, Deployment, Detection, Event, Occurrence, Project, SourceImage, Taxon
+from ami.ml.models import Algorithm
+
+
+class IncomingDetection(pydantic.BaseModel):
+    id: int
+    source_image_id: int
+    source_image_path: str
+    source_image_width: int
+    source_image_height: int
+    source_image_filesize: int
+    label: str
+    score: float
+    cropped_image_path: str | None = None
+    sequence_id: str | None = None  # This is the Occurrence ID on the ADC side (= detections in a sequence)
+    timestamp: datetime.datetime
+    detection_algorithm: str | None = None  # Name of the object detection algorithm used
+    classification_algorithm: str | None = None  # Classification algorithm used to generate the label & score
+    bbox: list[int]  # Bounding box in the format [x_min, y_min, x_max, y_max]
+
+
+class IncomingOccurrence(pydantic.BaseModel):
+    id: str
+    label: str
+    best_score: float
+    start_time: datetime.datetime
+    end_time: datetime.datetime
+    duration: datetime.timedelta
+    deployment: str
+    event: str
+    num_frames: int
+    # cropped_image_path: pathlib.Path
+    # source_image_id: int
+    examples: list[
+        IncomingDetection
+    ]  # These are the individual detections with source image data, bounding boxes and predictions
+    example_crop: str | None = None
+    # detections: list[object]
+    # deployment: object
+    # captures: list[object]
+
+
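Editor's note: the two pydantic models above are added by this patch, but the handle() method further down still works with raw dicts. A minimal, hypothetical sketch of validating the loaded JSON with them (pydantic v1-style API, matching the parse_obj usage elsewhere in this diff); illustrative only, not part of the patch:

import json
import pydantic

with open("occurrences.json") as f:
    raw = json.load(f)

# Fails loudly if a field is missing or has the wrong type
occurrences = pydantic.parse_obj_as(list[IncomingOccurrence], raw)
print(f"Validated {len(occurrences)} occurrences")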
 class Command(BaseCommand):
     r"""Import trap data from a JSON file exported from the AMI data companion.
 
-    occurrences.json
-    [
+    occurrences.json
+
+    # CURRENT EXAMPLE JSON STRUCTURE:
+    {
+        "id":"SEQ-91",
+        "label":"Azochis rufidiscalis",
+        "best_score":0.4857344627,
+        "start_time":"2023-01-25T03:49:59.000",
+        "end_time":"2023-01-25T03:49:59.000",
+        "duration":"P0DT0H0M0S",
+        "deployment":"snapshots",
+        "event":"2023-01-24",
+        "num_frames":1,
+        "examples":[
+            {
+                "id":91,
+                "source_image_id":402,
+                "source_image_path":"2023_01_24\/257-20230125034959-snapshot.jpg",
+                "source_image_width":4096,
+                "source_image_height":2160,
+                "source_image_filesize":1276685,
+                "label":"Azochis rufidiscalis",
+                "score":0.4857344627,
+                "cropped_image_path":"\/media\/michael\/ZWEIBEL\/ami-ml-data\/trapdata\/crops\/820709c454b529d5cf44e59fea1f4b5b.jpg",
+                "sequence_id":"20230124-SEQ-91",
+                "timestamp":"2023-01-25T03:49:59.000",
+                "bbox":[
+                    2191,
+                    413,
+                    2568,
+                    638
+                ]
+            }
+        ],
+        "example_crop":null
+    },
     {
-        "id":"20220620-SEQ-207259",
-        "label":"Baileya ophthalmica",
-        "best_score":0.6794486046,
-        "start_time":"2022-06-21T09:23:00.000Z",
-        "end_time":"2022-06-21T09:23:00.000Z",
+        "id":"SEQ-86",
+        "label":"Sphinx canadensis",
+        "best_score":0.4561957121,
+        "start_time":"2023-01-24T20:11:59.000",
+        "end_time":"2023-01-24T20:11:59.000",
         "duration":"P0DT0H0M0S",
-        "deployment":"Vermont-Snapshots-Sample",
-        "event":{
-            "id":19,
-            "day":"2022-06-20T00:00:00.000",
-            "url":null
-        },
+        "deployment":"snapshots",
+        "event":"2023-01-24",
         "num_frames":1,
         "examples":[
            {
-                "id":207259,
-                "source_image_id":15050,
-                "source_image_path":"2022_06_21_snapshots\/20220621052300-301-snapshot.jpg",
+                "id":86,
+                "source_image_id":88,
+                "source_image_path":"2023_01_24\/55-20230124201159-snapshot.jpg",
                 "source_image_width":4096,
                 "source_image_height":2160,
-                "source_image_filesize":1599836,
-                "label":"Baileya ophthalmica",
-                "score":0.6794486046,
-                "cropped_image_path":"exports\/occurrences_images\/20220620-SEQ-207259-963edb524a59504392d4bec06717857a.jpg",
-                "sequence_id":"20220620-SEQ-207259",
-                "timestamp":"2022-06-21T09:23:00.000Z",
+                "source_image_filesize":1013757,
+                "label":"Sphinx canadensis",
+                "score":0.4561957121,
+                "cropped_image_path":"\/media\/michael\/ZWEIBEL\/ami-ml-data\/trapdata\/crops\/839fd6565461939ef946751b87003eda.jpg",
+                "sequence_id":"20230124-SEQ-86",
+                "timestamp":"2023-01-24T20:11:59.000",
                 "bbox":[
-                    3598,
-                    1074,
-                    3821,
-                    1329
+                    1629,
+                    0,
+                    1731,
+                    25
                 ]
            }
         ],
-        "url":null
+        "example_crop":null
     },
-    ]
     """
 
     help = "Import trap data from AMI data manager occurrences.json file"
 
     def add_arguments(self, parser):
         parser.add_argument("occurrences", type=str)
+        parser.add_argument("project_id", type=str, help="Project to import to")
 
     def handle(self, *args, **options):
         occurrences = json.load(open(options["occurrences"]))
+        project_id = options["project_id"]
+
+        project = Project.objects.get(pk=project_id)
+        self.stdout.write(self.style.SUCCESS('Importing to project "%s"' % project.name))
 
-        project, created = Project.objects.get_or_create(name="Default Project")
+        """
+        -) Collect all Deployments that need to be created or fetched
+        -) Collect all SourceImages that need to be created or fetched
+        -) Collect all Occurrences that need to be created or fetched
+        -) Create Deployments, linking them to the correct Project
+        -) Create SourceImages, linking them to the correct Occurrence and Deployment
+        -) Create Occurrences, linking them to the correct Deployment and Project
+        -) Generate events (save deployments to trigger event generation)
+        -) Create Detections, linking them to the correct Occurrence and SourceImage
+        -) Create Classifications, linking them to the correct Detection and Taxon
+        -) commit transaction, if transaction is possible
+        """
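Editor's note: the checklist above ends with "commit transaction, if transaction is possible", but the import below is not wrapped in one. A minimal sketch using Django's standard API (illustrative, not part of the patch):

from django.db import transaction

with transaction.atomic():
    # create deployments, source images, occurrences, detections and classifications here;
    # if any step raises, everything inside the block is rolled back
    ...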
+
+        # Create a fallback algorithm for detections missing algorithm info
+        default_classification_algorithm, created = Algorithm.objects.get_or_create(
+            name="Unknown classifier from ADC import",
+            task_type="classification",
+            defaults={
+                "description": "Unknown classification model imported from AMI data companion occurrences.json",
+                "version": 0,
+            },
+        )
         if created:
-            self.stdout.write(self.style.SUCCESS('Successfully created project "%s"' % project))
-        algorithm, created = Algorithm.objects.get_or_create(name="Latest Model", version="1.0")
-        for occurrence in occurrences:
+            self.stdout.write(
+                self.style.SUCCESS('Created fallback algorithm "%s"' % default_classification_algorithm.name)
+            )
+        default_detection_algorithm, created = Algorithm.objects.get_or_create(
+            name="Unknown object detector from ADC import",
+            task_type="localization",
+            defaults={
+                "description": "Unknown object detection model imported from AMI data companion occurrences.json",
+                "version": 0,
+            },
+        )
+
+        # Process each occurrence from the JSON file
+        for occurrence_data in occurrences:
+            # Get or create deployment
             deployment, created = Deployment.objects.get_or_create(
-                name=occurrence["deployment"],
+                name=occurrence_data["deployment"],
                 project=project,
             )
             if created:
                 self.stdout.write(self.style.SUCCESS('Successfully created deployment "%s"' % deployment))
 
-            event, created = Event.objects.get_or_create(
-                start=parse_date(occurrence["event"]["day"]),
-                deployment=deployment,
-            )
+            # Get or create taxon for the occurrence
+            best_taxon, created = Taxon.objects.get_or_create(name=occurrence_data["label"])
             if created:
-                self.stdout.write(self.style.SUCCESS('Successfully created event "%s"' % event))
+                self.stdout.write(self.style.SUCCESS('Successfully created taxon "%s"' % best_taxon))
 
-            best_taxon, created = Taxon.objects.get_or_create(name=occurrence["label"])
-            occ = Occurrence.objects.create(
-                event=event,
+            # Create occurrence
+            occurrence = Occurrence.objects.create(
+                event=None,  # will be assigned when events are grouped
                 deployment=deployment,
                 project=project,
                 determination=best_taxon,
+                determination_score=occurrence_data["best_score"],
             )
-            self.stdout.write(self.style.SUCCESS('Successfully created occurrence "%s"' % occ))
+            self.stdout.write(self.style.SUCCESS('Successfully created occurrence "%s"' % occurrence))
 
-            for example in occurrence["examples"]:
+            # Process each detection example in the occurrence
+            for example in occurrence_data["examples"]:
                 try:
+                    # Create or get source image
                     image, created = SourceImage.objects.get_or_create(
                         path=example["source_image_path"],
-                        timestamp=parse_date(example["timestamp"]),
-                        event=event,
                         deployment=deployment,
-                        width=example["source_image_width"],
-                        height=example["source_image_height"],
-                        size=example["source_image_filesize"],
+                        defaults={
+                            "timestamp": parse_date(example["timestamp"]),
+                            "event": None,  # will be assigned when events are calculated
+                            "project": project,
+                            "width": example["source_image_width"],
+                            "height": example["source_image_height"],
+                            "size": example["source_image_filesize"],
+                        },
                     )
                     if created:
                         self.stdout.write(self.style.SUCCESS('Successfully created image "%s"' % image))
+
                 except KeyError as e:
-                    self.stdout.write(self.style.ERROR('Error creating image "%s"' % e))
-                    image = None
-
-                if image:
-                    detection, created = Detection.objects.get_or_create(
-                        occurrence=occ,
-                        source_image=image,
-                        timestamp=parse_date(example["timestamp"]),
-                        path=example["cropped_image_path"],
-                        bbox=example["bbox"],
-                    )
-                    if created:
-                        self.stdout.write(self.style.SUCCESS('Successfully created detection "%s"' % detection))
-                else:
-                    detection = None
-
-                taxon, created = Taxon.objects.get_or_create(name=example["label"])
-
-                if detection:
-                    one_day_later = datetime.timedelta(seconds=60 * 60 * 24)
-                    classification, created = Classification.objects.get_or_create(
-                        score=example["score"],
-                        determination=taxon,
-                        detection=detection,
-                        type="machine",
-                        algorithm=algorithm,
-                        timestamp=parse_date(example["timestamp"]) + one_day_later,
-                    )
-                    if created:
-                        self.stdout.write(
-                            self.style.SUCCESS('Successfully created classification "%s"' % classification)
+                    self.stdout.write(self.style.ERROR('Error creating image - missing field: "%s"' % e))
+                    continue
+
+                # Create detection
+                detection, created = Detection.objects.get_or_create(
+                    occurrence=occurrence,
+                    source_image=image,
+                    bbox=example["bbox"],
+                    defaults={
+                        "path": example.get("cropped_image_path"),
+                        "timestamp": parse_date(example["timestamp"]),
+                    },
+                )
+                if created:
+                    self.stdout.write(self.style.SUCCESS('Successfully created detection "%s"' % detection))
+
+                # Get or create taxon for this specific detection
+                detection_taxon, created = Taxon.objects.get_or_create(name=example["label"])
+                if created:
+                    self.stdout.write(self.style.SUCCESS('Successfully created taxon "%s"' % detection_taxon))
+
+                # Determine which algorithm to use
+                algorithm_to_use = default_classification_algorithm
+                if example.get("classification_algorithm"):
+                    # Try to find an algorithm with this name
+                    try:
+                        algorithm_to_use = Algorithm.objects.get(name=example["classification_algorithm"])
+                    except Algorithm.DoesNotExist:
+                        # Create new algorithm if it doesn't exist
+                        algorithm_to_use, created = Algorithm.objects.get_or_create(
+                            name=example["classification_algorithm"],
+                            task_type="classification",
+                            defaults={
+                                "description": "Algorithm imported from AMI data companion: "
+                                f"{example['classification_algorithm']}",
+                                "version": 0,
+                            },
                         )
+                        if created:
+                            self.stdout.write(self.style.SUCCESS('Created algorithm "%s"' % algorithm_to_use.name))
+
+                # Create classification
+                classification, created = Classification.objects.get_or_create(
+                    detection=detection,
+                    algorithm=algorithm_to_use,
+                    taxon=detection_taxon,
+                    defaults={
+                        "score": example["score"],
+                        "timestamp": parse_date(example["timestamp"]),
+                        "terminal": True,
+                    },
+                )
+                if created:
+                    self.stdout.write(self.style.SUCCESS('Successfully created classification "%s"' % classification))
 
-        # Update event start and end times based on the first and last detections
-        for event in Event.objects.all():
+        # Regroup images into events for all deployments that were modified
+        self.stdout.write(self.style.SUCCESS("Regrouping images into events..."))
+        deployments_to_update = Deployment.objects.filter(project=project)
+        for deployment in deployments_to_update:
+            deployment.save(regroup_async=False)
+            self.stdout.write(self.style.SUCCESS('Updated events for deployment "%s"' % deployment))
+
+        # Update event timestamps
+        events_updated = 0
+        for event in Event.objects.filter(project=project):
             event.save()
+            events_updated += 1
+
+        self.stdout.write(self.style.SUCCESS("Updated %d events" % events_updated))
+        self.stdout.write(self.style.SUCCESS("Import completed successfully!"))
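Editor's note: a hedged usage example for the rewritten command, derived from its two positional arguments (file name and project ID are illustrative):

docker compose run --rm django python manage.py import_trapdata_project occurrences.json 1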
diff --git a/ami/main/models.py b/ami/main/models.py
index ce5fbadfa..66eb5e440 100644
--- a/ami/main/models.py
+++ b/ami/main/models.py
@@ -1163,12 +1163,25 @@ def group_images_into_events(
             defaults={"start": start_date, "end": end_date},
         )
         events.append(event)
-        SourceImage.objects.filter(deployment=deployment, timestamp__in=group).update(event=event)
+        source_images = SourceImage.objects.filter(deployment=deployment, timestamp__in=group)
+        source_images.update(event=event)
+        event.save()  # Update start and end times and other cached fields
         logger.info(
             f"Created/updated event {event} with {len(group)} images for deployment {deployment}. "
             f"Duration: {event.duration_label()}"
         )
+        # Update occurrences to point to the new event
+        occurrences_updated = (
+            Occurrence.objects.filter(
+                detections__source_image__in=source_images,
+            )
+            .exclude(event=event)
+            .update(event=event)
+        )
+        logger.info(
+            f"Updated {occurrences_updated} occurrences to point to event {event} for deployment {deployment}."
+        )
 
     logger.info(
         f"Done grouping {len(image_timestamps)} captures into {len(events)} events " f"for deployment {deployment}"
@@ -1183,6 +1196,19 @@ def group_images_into_events(
             logger.info(f"Setting image dimensions for event {event}")
             set_dimensions_for_collection(event)
 
+    # Warn if any occurrences belonging to the deployment are not assigned to an event
+    logger.info("Checking for ungrouped occurrences in deployment")
+    ungrouped_occurrences = Occurrence.objects.filter(
+        deployment=deployment,
+        event__isnull=True,
+    )
+    if ungrouped_occurrences.exists():
+        logger.warning(
+            f"Found {ungrouped_occurrences.count()} occurrences in deployment {deployment} "
+            "that are not assigned to any event. "
+            "This may indicate that some images were not grouped correctly."
+        )
+
     logger.info("Updating relevant cached fields on deployment")
     deployment.events_count = len(events)
     deployment.save(update_calculated_fields=False, update_fields=["events_count"])
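Editor's note: the warning above only reports occurrences left without an event. A possible backfill, deriving the event from each occurrence's first capture (a sketch only, relying on the detections/source_image relations used elsewhere in this diff):

for occurrence in ungrouped_occurrences:
    detection = occurrence.detections.select_related("source_image").first()
    if detection and detection.source_image and detection.source_image.event:
        occurrence.event = detection.source_image.event
        occurrence.save()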
diff --git a/ami/ml/management/commands/import_pipeline_results.py b/ami/ml/management/commands/import_pipeline_results.py
new file mode 100644
index 000000000..0ae274356
--- /dev/null
+++ b/ami/ml/management/commands/import_pipeline_results.py
@@ -0,0 +1,119 @@
+import json
+from pathlib import Path
+
+from django.core.management.base import BaseCommand, CommandError
+from django.db import transaction
+
+from ami.main.models import Project
+from ami.ml.models.pipeline import save_results
+from ami.ml.schemas import PipelineResultsResponse
+
+
+class Command(BaseCommand):
+    help = "Import pipeline results from a JSON file into the database"
+
+    def add_arguments(self, parser):
+        parser.add_argument("json_file", type=str, help="Path to JSON file containing PipelineResultsResponse data")
+        parser.add_argument("--project", type=int, required=True, help="Project ID to import the data into")
+        parser.add_argument("--dry-run", action="store_true", help="Validate the data without saving to database")
+        parser.add_argument(
+            "--public-base-url",
+            type=str,
+            help="Base URL for images if paths are relative (e.g., http://0.0.0.0:7070/)",
+        )
+
+    def handle(self, *args, **options):
+        json_file_path = Path(options["json_file"])
+        project_id = options["project"]
+        dry_run = options.get("dry_run", False)
+        public_base_url = options.get("public_base_url")
+
+        # Validate that the JSON file exists
+        if not json_file_path.exists():
+            raise CommandError(f"JSON file does not exist: {json_file_path}")
+
+        # Validate that the project exists
+        try:
+            project = Project.objects.get(pk=project_id)
+        except Project.DoesNotExist:
+            raise CommandError(f"Project with ID {project_id} does not exist")
+
+        self.stdout.write(f"Reading JSON file: {json_file_path}")
+
+        # Read and parse the JSON file
+        try:
+            with open(json_file_path, encoding="utf-8") as f:
+                json_data = json.load(f)
+        except json.JSONDecodeError as e:
+            raise CommandError(f"Invalid JSON in file {json_file_path}: {e}")
+        except Exception as e:
+            raise CommandError(f"Error reading file {json_file_path}: {e}")
+
+        # Validate the JSON data against the PipelineResultsResponse schema
+        try:
+            pipeline_results = PipelineResultsResponse(**json_data)
+        except Exception as e:
+            raise CommandError(f"Invalid PipelineResultsResponse data: {e}")
+
+        self.stdout.write(
+            self.style.SUCCESS(
+                f"Successfully validated PipelineResultsResponse with:"
+                f"\n - Pipeline: {pipeline_results.pipeline}"
+                f"\n - Source images: {len(pipeline_results.source_images)}"
+                f"\n - Detections: {len(pipeline_results.detections)}"
+                f"\n - Algorithms: {len(pipeline_results.algorithms)}"
+            )
+        )
+
+        if dry_run:
+            self.stdout.write(self.style.WARNING("Dry run mode - no data will be saved to database"))
+            return
+
+        # Import the data using save_results function
+        self.stdout.write(f"Importing data into project: {project} (ID: {project_id})")
+
+        try:
+            with transaction.atomic():
+                # Call the save_results function with create_missing_source_images=True
+                results_json = pipeline_results.json()
+                result = save_results(
+                    results_json=results_json,
+                    job_id=None,
+                    return_created=True,
+                    create_missing_source_images=True,
+                    project_id=project_id,
+                    public_base_url=public_base_url,
+                )
+
+                if result:
+                    self.stdout.write(
+                        self.style.SUCCESS(
+                            f"Successfully imported pipeline results:"
+                            f"\n - Pipeline: {result.pipeline}"
+                            f"\n - Source images processed: {len(result.source_images)}"
+                            f"\n - Detections created: {len(result.detections)}"
+                            f"\n - Classifications created: {len(result.classifications)}"
+                            f"\n - Algorithms used: {len(result.algorithms)}"
+                            f"\n - Total processing time: {result.total_time:.2f} seconds"
+                        )
+                    )
+
+                    # Re-save all deployments in the results to ensure they are up-to-date
+                    # Must loop through the source images
+                    self.stdout.write(self.style.SUCCESS("Updating sessions and stations"))
+                    deployments = {
+                        source_image.deployment for source_image in result.source_images if source_image.deployment
+                    }
+                    for deployment in deployments:
+                        deployment.save(regroup_async=False)
+                else:
+                    self.stdout.write(self.style.WARNING("Import completed but no result object returned"))
+
+        except Exception as e:
+            raise CommandError(f"Error importing pipeline results: {e}")
+
+        self.stdout.write(
+            self.style.SUCCESS(
+                f"Pipeline results successfully imported into project '{project.name}' (ID: {project_id})"
+            )
+        )
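Editor's note: hedged invocation examples for the new command, derived from its arguments (paths, project ID and base URL are illustrative):

docker compose run --rm django python manage.py import_pipeline_results results.json --project 1 --dry-run
docker compose run --rm django python manage.py import_pipeline_results results.json --project 1 --public-base-url http://0.0.0.0:7070/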
diff --git a/ami/ml/management/commands/test_class_masking.py b/ami/ml/management/commands/test_class_masking.py
new file mode 100644
index 000000000..2deeda1d9
--- /dev/null
+++ b/ami/ml/management/commands/test_class_masking.py
@@ -0,0 +1,53 @@
+from django.core.management.base import BaseCommand
+
+from ami.main.models import SourceImageCollection
+from ami.ml.models import Algorithm
+
+
+class Command(BaseCommand):
+    help = """
+    Filter classifications by a provided taxa list
+
+    # Usage:
+    docker compose run --rm django python manage.py test_class_masking --project 1 --collection 1 --taxa-list 1 --algorithm 1
+    """
+
+    def add_arguments(self, parser):
+        parser.add_argument("--project", type=int, help="Project ID to process")
+        parser.add_argument("--collection", type=int, help="Source image collection ID to process")
+        parser.add_argument("--taxa-list", type=int, help="Taxa list ID to filter by")
+        parser.add_argument(
+            "--algorithm", type=int, help="Algorithm ID to use for filtering classifications (e.g. the global model)"
+        )
+        parser.add_argument("--dry-run", action="store_true", help="Don't make any changes")
+
+    def handle(self, *args, **options):
+        project_id = options["project"]
+        collection_id = options["collection"]
+        taxa_list_id = options["taxa_list"]
+        algorithm_id = options["algorithm"]
+
+        from ami.main.models import Project, TaxaList
+        from ami.ml.post_processing.class_masking import update_occurrences_in_collection
+
+        try:
+            project = Project.objects.get(id=project_id)
+            taxa_list = TaxaList.objects.get(id=taxa_list_id)
+            collection = SourceImageCollection.objects.get(id=collection_id)
+            algorithm = Algorithm.objects.get(id=algorithm_id)
+        except (
+            Project.DoesNotExist,
+            TaxaList.DoesNotExist,
+            SourceImageCollection.DoesNotExist,
+            Algorithm.DoesNotExist,
+        ) as e:
+            self.stdout.write(f"Error: {e}")
+            return
+
+        self.stdout.write(f"Processing project: {project.name}, taxa list: {taxa_list.name}")
+        self.stdout.write("Filtering classifications based on the taxa list...")
+        # Log collection
+        self.stdout.write(f"Collection: {collection.name} (ID: {collection.pk})")
+
+        update_occurrences_in_collection(
+            collection=collection,
+            taxa_list=taxa_list,
+            algorithm=algorithm,
+            params={},
+            job=None,
+        )
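Editor's note: the command above defines --dry-run but handle() never reads it. A minimal sketch of how the flag could gate the update call (illustrative, not part of the patch):

if options["dry_run"]:
    self.stdout.write("Dry run: classifications would be recalculated but nothing will be saved")
    return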
diff --git a/ami/ml/management/commands/test_logistic_binning.py b/ami/ml/management/commands/test_logistic_binning.py
new file mode 100644
index 000000000..7374dda0a
--- /dev/null
+++ b/ami/ml/management/commands/test_logistic_binning.py
@@ -0,0 +1,218 @@
+import random
+
+import numpy as np
+from django.core.management.base import BaseCommand
+from django.db import models
+
+from ami.main.models import Classification, Taxon
+from ami.ml.models import Algorithm
+
+
+class Command(BaseCommand):
+    help = """
+    Sample classifications by score quartiles and identify occurrences for human verification
+
+    # Usage:
+    docker compose run --rm django python manage.py test_logistic_binning --project 1 --species "Apogeshna stenialis" \
+        --algorithm 23
+    """
+
+    def add_arguments(self, parser):
+        parser.add_argument("--project", type=int, help="Project ID to process")
+        parser.add_argument("--species", type=str, required=True, help="Species name to analyze scores for")
+        parser.add_argument("--algorithm", type=int, help="Algorithm ID to use (default: auto-select)")
+        parser.add_argument("--sample-size", type=int, default=1000, help="Initial random sample size")
+        parser.add_argument("--bin-sample-size", type=int, default=50, help="Sample size from each quartile bin")
+        parser.add_argument("--dry-run", action="store_true", help="Don't make any changes")
+
+    def handle(self, *args, **options):
+        project_id = options["project"]
+        species_name = options["species"]
+        algorithm_id = options.get("algorithm")
+        sample_size = options["sample_size"]
+        bin_sample_size = options["bin_sample_size"]
+        dry_run = options["dry_run"]
+
+        # Look up the target species
+        try:
+            target_taxon = Taxon.objects.get(name=species_name, active=True)
+            self.stdout.write(f"Target species: {target_taxon}")
+        except Taxon.DoesNotExist:
+            self.stdout.write(f"Species '{species_name}' not found. Exiting.")
+            return
+
+        # Find the best algorithm if not specified
+        if algorithm_id:
+            try:
+                algorithm = Algorithm.objects.get(id=algorithm_id)
+            except Algorithm.DoesNotExist:
+                self.stdout.write(f"Algorithm with ID {algorithm_id} not found. Exiting.")
+                return
+        else:
+            # Find algorithm with most classifications that has task_type="classification"
+            algorithm = (
+                Algorithm.objects.filter(
+                    task_type="classification",
+                    category_map__isnull=False,
+                    classifications__detection__source_image__project_id=project_id,
+                )
+                .annotate(classification_count=models.Count("classifications"))
+                .order_by("-classification_count")
+                .first()
+            )
+            if not algorithm:
+                self.stdout.write("No suitable classification algorithm found. Exiting.")
+                return
+
+        self.stdout.write(f"Using algorithm: {algorithm}")
+
+        # Check if target species is in the algorithm's category map
+        if not algorithm.category_map:
+            self.stdout.write("Algorithm has no category map. Exiting.")
+            return
+
+        try:
+            species_index = algorithm.category_map.labels.index(target_taxon.name)
+            self.stdout.write(f"Species '{target_taxon.name}' found at index {species_index} in category map")
+        except ValueError:
+            self.stdout.write(f"Species '{target_taxon.name}' not found in algorithm's category map. Exiting.")
+            return
+
+        # Get all classifications for the project from this algorithm
+        classifications_qs = Classification.objects.select_related("detection__occurrence").filter(
+            detection__source_image__project_id=project_id,
+            algorithm=algorithm,
+            scores__isnull=False,
+            detection__occurrence__isnull=False,
+        )
+
+        total_classifications = classifications_qs.count()
+        self.stdout.write(f"Found {total_classifications} classifications from algorithm {algorithm.name}")
+
+        if total_classifications == 0:
+            self.stdout.write("No classifications found. Exiting.")
+            return
+
+        # Randomly sample classifications
+        sample_size = min(sample_size, total_classifications)
+        self.stdout.write(f"Randomly sampling {sample_size} classifications...")
+
+        # Get random sample
+        sampled_classifications = list(classifications_qs.order_by("?")[:sample_size])
+
+        # Extract species-specific scores from the scores array
+        species_scores = []
+        for classification in sampled_classifications:
+            if classification.scores and len(classification.scores) > species_index:
+                species_score = classification.scores[species_index]
+                if species_score is not None:
+                    species_scores.append(species_score)
+
+        if not species_scores:
+            self.stdout.write(f"No valid scores found for species '{target_taxon.name}'. Exiting.")
+            return
+
+        self.stdout.write(f"Found {len(species_scores)} valid species-specific scores")
+
+        # Calculate quartiles using species-specific scores
+        q1 = np.percentile(species_scores, 25)
+        q2 = np.percentile(species_scores, 50)  # median
+        q3 = np.percentile(species_scores, 75)
+
+        self.stdout.write(f"Species score quartiles: Q1={q1:.3f}, Q2={q2:.3f}, Q3={q3:.3f}")
+
+        # Separate into bins based on species-specific scores
+        bins: dict[str, list[Classification]] = {
+            "Q1 (0-25%)": [],
+            "Q2 (25-50%)": [],
+            "Q3 (50-75%)": [],
+            "Q4 (75-100%)": [],
+        }
+
+        for classification in sampled_classifications:
+            # Get species-specific score for this classification
+            if classification.scores and len(classification.scores) > species_index:
+                species_score = classification.scores[species_index]
+                if species_score is not None:
+                    if species_score <= q1:
+                        bins["Q1 (0-25%)"].append(classification)
+                    elif species_score <= q2:
+                        bins["Q2 (25-50%)"].append(classification)
+                    elif species_score <= q3:
+                        bins["Q3 (50-75%)"].append(classification)
+                    else:
+                        bins["Q4 (75-100%)"].append(classification)
+
+        # Print bin statistics
+        self.stdout.write("\nBin statistics:")
+        for bin_name, bin_classifications in bins.items():
+            self.stdout.write(f" {bin_name}: {len(bin_classifications)} classifications")
+
+        # Sample from each bin
+        sampled_occurrences = set()
+
+        self.stdout.write(f"\nSampling up to {bin_sample_size} classifications from each bin...")
+
+        for bin_name, bin_classifications in bins.items():
+            if not bin_classifications:
+                self.stdout.write(f" {bin_name}: No classifications to sample")
+                continue
+
+            # Sample from this bin
+            bin_sample = random.sample(bin_classifications, min(bin_sample_size, len(bin_classifications)))
+
+            # Extract occurrences
+            bin_occurrences = {c.detection.occurrence for c in bin_sample if c.detection and c.detection.occurrence}
+            sampled_occurrences.update(bin_occurrences)
+
+            self.stdout.write(
+                f" {bin_name}: Sampled {len(bin_sample)} classifications -> {len(bin_occurrences)} occurrences"
+            )
+
+        # Print summary
+        self.stdout.write("\n=== SUMMARY ===")
+        self.stdout.write(f"Total unique occurrences for human verification: {len(sampled_occurrences)}")
+
+        if sampled_occurrences:
+            # Group by determination for summary
+            determination_counts = {}
+            score_ranges = {"min": float("inf"), "max": float("-inf")}
+
+            for occurrence in sampled_occurrences:
+                determination = occurrence.determination
+                determination_name = str(determination) if determination else "Undetermined"
+                determination_counts[determination_name] = determination_counts.get(determination_name, 0) + 1
+
+                if occurrence.determination_score:
+                    score_ranges["min"] = min(score_ranges["min"], occurrence.determination_score)
+                    score_ranges["max"] = max(score_ranges["max"], occurrence.determination_score)
+
+            self.stdout.write("\nOccurrences by determination:")
+            for determination, count in sorted(determination_counts.items(), key=lambda x: x[1], reverse=True):
+                self.stdout.write(f" {determination}: {count}")
+
+            if score_ranges["min"] != float("inf"):
+                self.stdout.write(
+                    f"\nDetermination score range: {score_ranges['min']:.3f} - {score_ranges['max']:.3f}"
+                )
+
+            # Sample occurrence details for verification
+            sample_occurrences = list(sampled_occurrences)[:10]  # Show first 10 as examples
+            self.stdout.write("\nSample occurrences for verification (first 10):")
+            for i, occurrence in enumerate(sample_occurrences, 1):
+                determination = occurrence.determination or "Undetermined"
+                score = f"{occurrence.determination_score:.3f}" if occurrence.determination_score else "N/A"
+                deployment = occurrence.deployment.name if occurrence.deployment else "Unknown"
+                self.stdout.write(
+                    f" {i}. Occurrence #{occurrence.pk} - {determination} (score: {score}) - {deployment}"
+                )
+
+        if not dry_run:
+            self.stdout.write(
+                f"\nWould process {len(sampled_occurrences)} occurrences for verification (tagging not implemented yet)"
+            )
+            # TODO: Add logic here to tag occurrences for verification
+        else:
+            self.stdout.write(
+                f"\n[DRY-RUN MODE] Would process {len(sampled_occurrences)} occurrences for verification"
+            )
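Editor's note: the manual if/elif quartile binning above can also be expressed with numpy; a small equivalent check (illustrative only):

import numpy as np

scores = [0.1, 0.3, 0.5, 0.7, 0.9]
edges = np.percentile(scores, [25, 50, 75])  # Q1, Q2, Q3 cut points
# right=True reproduces the `score <= q` comparisons used above: 0 -> Q1 ... 3 -> Q4
bin_indexes = np.digitize(scores, edges, right=True)
print(bin_indexes.tolist())  # [0, 0, 1, 2, 3]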
diff --git a/ami/ml/models/algorithm.py b/ami/ml/models/algorithm.py
index 5a05d5352..d3146fb8e 100644
--- a/ami/ml/models/algorithm.py
+++ b/ami/ml/models/algorithm.py
@@ -3,9 +3,10 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from ami.main.models import Classification
+    from ami.main.models import Classification, Taxon
     from ami.ml.models import Pipeline
 
+import logging
 import typing
 
 from django.contrib.postgres.fields import ArrayField
@@ -14,6 +15,8 @@
 
 from ami.base.models import BaseModel
 
+logger = logging.getLogger(__name__)
+
 
 @typing.final
 class AlgorithmCategoryMap(BaseModel):
@@ -58,7 +61,9 @@ def get_category(self, label, label_field="label"):
         # Can use JSON containment operators
         return self.data.index(next(category for category in self.data if category[label_field] == label))
 
-    def with_taxa(self, category_field="label", only_indexes: list[int] | None = None):
+    def with_taxa(
+        self, category_field="label", only_indexes: list[int] | None = None
+    ) -> list[dict[str, str | int | Taxon | None]]:
         """
         Add Taxon objects to the category map, or None if no match
 
@@ -69,7 +74,7 @@ def with_taxa(self, category_field="label", only_indexes: list[int] | None = Non
 
         @TODO consider creating missing taxa?
         """
-        from ami.main.models import Taxon
+        from ami.main.models import Taxon, TaxonRank
 
         if only_indexes:
            labels_data = [self.data[i] for i in only_indexes]
@@ -88,6 +93,14 @@ def with_taxa(self, category_field="label", only_indexes: list[int] | None = Non
 
         for category in labels_data:
             taxon = taxon_map.get(category[category_field])
+            # Import all taxa in the category map that are not in the database yet
+            if not taxon:
+                taxon = Taxon.objects.create(
+                    name=category["label"],
+                    rank=category.get("taxon_rank", TaxonRank.SPECIES),  # @TODO: make this flexible
+                )
+                # @TODO: this doesn't seem to be working - logging works but no species are registered
+                logger.info(f"Registered new taxon {taxon}")
             category["taxon"] = taxon
 
         return labels_data
diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py
index d2f6c48dc..af71ca4ac 100644
--- a/ami/ml/models/pipeline.py
+++ b/ami/ml/models/pipeline.py
@@ -41,6 +41,7 @@
     AlgorithmConfigResponse,
     AlgorithmReference,
     ClassificationResponse,
+    DeploymentResponse,
     DetectionRequest,
     DetectionResponse,
     PipelineRequest,
@@ -345,11 +346,12 @@ def get_or_create_algorithm_and_category_map(
             " Will attempt to create one from the classification results."
         )
 
+    # @TODO update the unique constraint to use key & version instead of name & version
     algo, _created = Algorithm.objects.get_or_create(
-        key=algorithm_config.key,
+        name=algorithm_config.name,
         version=algorithm_config.version,
         defaults={
-            "name": algorithm_config.name,
+            "key": algorithm_config.key,
             "task_type": algorithm_config.task_type,
             "version_name": algorithm_config.version_name,
             "uri": algorithm_config.uri,
@@ -415,7 +417,7 @@ def get_or_create_detection(
 
     # A detection may have a pre-existing crop image URL or not.
     # If not, a new one will be created in a periodic background task.
-    if detection_resp.crop_image_url and detection_resp.crop_image_url.strip("/"):
+    if detection_resp.crop_image_url and detection_resp.crop_image_url.startswith(("http://", "https://")):
         # Ensure that the crop image URL is not empty or only a slash. None is fine.
         crop_url = detection_resp.crop_image_url
     else:
@@ -768,6 +770,113 @@ def create_classifications(
     return existing_classifications + new_classifications
 
 
+def get_or_create_deployments(
+    deployments_data: list[DeploymentResponse],
+    project_id: int,
+    logger: logging.Logger = logger,
+) -> dict[str, Deployment]:
+    """
+    Create or get deployments from the deployments data in the pipeline results.
+
+    :param deployments_data: List of deployment entries from the pipeline results
+    :param project_id: Project ID to create deployments for
+    :param logger: Logger instance
+
+    :return: Dictionary mapping deployment keys to Deployment objects
+    """
+    from ami.main.models import Project, get_or_create_default_deployment
+
+    project = Project.objects.get(pk=project_id)
+    deployments = {}
+
+    for deployment_data in deployments_data:
+        deployment_name = deployment_data.name
+
+        if deployment_name not in deployments:
+            deployment = get_or_create_default_deployment(
+                project=project,
+                name=deployment_name,
+            )
+            deployments[deployment_name] = deployment
+
+    return deployments
+
+
+def create_source_images(
+    source_images_data: list[SourceImageResponse],
+    deployments: dict[str, Deployment],
+    project_id: int,
+    public_base_url: str | None = None,
+    logger: logging.Logger = logger,
+) -> dict[str, int]:
+    """
+    Create source images from pipeline results data.
+
+    This assumes the IDs are external IDs from the pipeline results creator
+    and maps them to internal IDs in the database.
+
+    This was created for an initial use case, needs to be tested for broader use cases.
+
+    :param source_images_data: List of source image dictionaries from raw JSON
+    :param deployments: Dictionary mapping deployment keys to Deployment objects
+    :param project_id: Project ID
+    :param public_base_url: Base URL for images if paths are relative
+    :param logger: Logger instance
+
+    :return: Dictionary mapping external IDs to internal source image IDs
+    """
+    import ami.utils.dates
+    from ami.main.models import Project
+
+    project = Project.objects.get(pk=project_id)
+    id_mapping = {}
+
+    for source_image_data in source_images_data:
+        external_id = source_image_data.id
+        url = source_image_data.url
+        deployment_info = source_image_data.deployment
+
+        if not deployment_info:
+            logger.warning(
+                f"The incoming source image {external_id} does not have a deployment specified. "
+                "This is required to create a SourceImage."
+            )
+            continue
+        else:
+            deployment_name = deployment_info.name
+            deployment = deployments[deployment_name]
+
+        # Check if source image already exists by URL and deployment
+        existing_image = SourceImage.objects.filter(deployment=deployment, path=url).first()
+
+        if existing_image:
+            source_image = existing_image
+            logger.debug(f"Using existing source image {source_image.pk} for path: {url}")
+        else:
+            # Extract timestamp from filename
+            timestamp = ami.utils.dates.get_image_timestamp_from_filename(url)
+
+            # Set public_base_url if provided and path is relative
+            final_public_base_url = None
+            if public_base_url and not url.startswith(("http://", "https://")):
+                final_public_base_url = public_base_url.rstrip("/")
+
+            source_image = SourceImage.objects.create(
+                path=url,
+                deployment=deployment,
+                project=project,
+                timestamp=timestamp,
+                public_base_url=final_public_base_url,
+            )
+            logger.info(f"Created source image {source_image.pk} for path: {url}")
+
+        # Map external ID to internal ID
+        id_mapping[external_id] = str(source_image.pk)
+        logger.debug(f"Mapped external ID {external_id} to internal ID {source_image.pk}")
+
+    return id_mapping
+
+
 def create_and_update_occurrences_for_detections(
     detections: list[Detection],
     logger: logging.Logger = logger,
@@ -864,6 +973,10 @@ def save_results(
     results_json: str | None = None,
     job_id: int | None = None,
     return_created=False,
+    create_missing_source_images: bool = False,
+    project_id: int | None = None,
+    public_base_url: str | None = None,
+    create_new_algorithms: bool = False,
 ) -> PipelineSaveResults | None:
     """
     Save results from ML pipeline API.
@@ -897,6 +1010,57 @@ def save_results(
         results = PipelineResultsResponse.parse_obj(results.dict())
 
     assert results, "No results from pipeline to save"
+
+    # Create missing source images and deployments if requested
+    if create_missing_source_images and project_id:
+        job_logger.info(f"Creating missing source images and deployments for project {project_id}")
+
+        deployments_data = results.deployments
+        source_images_data = results.source_images
+
+        if not deployments_data:
+            job_logger.warning(
+                "No deployments data found in results. "
+                "New source images will not be created without deployments data."
+            )
+        else:
+            get_or_create_deployments(
+                deployments_data=deployments_data,
+                project_id=project_id,
+                logger=job_logger,
+            )
+        deployments_map = {dep.name: dep for dep in Deployment.objects.filter(project_id=project_id)}
+        job_logger.info(f"Found {len(deployments_map)} existing deployments for project {project_id}")
+
+        if not source_images_data:
+            raise ValueError(
+                "No source images data found in results. "
+                "New detections cannot be created without source images data."
+            )
+
+        # Create source images from the external results data
+        # where the IDs do not match the internal IDs.
+        id_mapping = create_source_images(
+            source_images_data=source_images_data,
+            deployments=deployments_map,
+            project_id=project_id,
+            public_base_url=public_base_url,
+            logger=job_logger,
+        )
+
+        # Update the results to use internal IDs
+        for i, source_image_data in enumerate(source_images_data):
+            external_id = source_image_data.id
+            if external_id in id_mapping:
+                results.source_images[i].id = str(id_mapping[external_id])
+
+        # Update detection source_image_ids to use internal IDs
+        for detection in results.detections:
+            if detection.source_image_id in id_mapping:
+                detection.source_image_id = str(id_mapping[detection.source_image_id])
+
+        job_logger.debug(f"Created/found {len(id_mapping)} source images with ID mapping: {id_mapping}")
+
     source_images = SourceImage.objects.filter(pk__in=[int(img.id) for img in results.source_images]).distinct()
 
     pipeline, _created = Pipeline.objects.get_or_create(slug=results.pipeline, defaults={"name": results.pipeline})
@@ -905,19 +1069,23 @@ def save_results(
             f"The pipeline returned by the ML backend was not recognized, created a placeholder: {pipeline}"
         )
 
-    # Algorithms and category maps should be created in advance when registering the pipeline & processing service
-    # however they are also currently available in each pipeline results response as well.
-    # @TODO review if we should only use the algorithms from the pre-registered pipeline config instead of the results
-    algorithms_used = {
-        algo_key: get_or_create_algorithm_and_category_map(algo_config, logger=job_logger)
-        for algo_key, algo_config in results.algorithms.items()
-    }
-    # Add all algorithms initially reported in the pipeline response to the pipeline
-    for algo in algorithms_used.values():
-        pipeline.algorithms.add(algo)
+    if create_new_algorithms:
+        # Algorithms and category maps should be created in advance when registering the pipeline & processing service
+        # however they are also currently available in each pipeline results response as well.
+        # @TODO review if we should only use the algorithms from the pre-registered pipeline config instead of
+        # the results
+        algorithms_used = {
+            algo_key: get_or_create_algorithm_and_category_map(algo_config, logger=job_logger)
+            for algo_key, algo_config in results.algorithms.items()
+        }
+        # Add all algorithms initially reported in the pipeline response to the pipeline
+        for algo in algorithms_used.values():
+            pipeline.algorithms.add(algo)
 
-    algos_reported = [f" {algo.task_type}: {algo_key} ({algo})\n" for algo_key, algo in algorithms_used.items()]
-    job_logger.info(f"Algorithms reported in pipeline response: \n{''.join(algos_reported)}")
+        algos_reported = [f" {algo.task_type}: {algo_key} ({algo})\n" for algo_key, algo in algorithms_used.items()]
+        job_logger.info(f"Algorithms reported in pipeline response: \n{''.join(algos_reported)}")
+    else:
+        algorithms_used = {algo.key: algo for algo in pipeline.algorithms.all()}
 
     detections = create_detections(
         detections=results.detections,
@@ -1160,7 +1328,7 @@ def choose_processing_service_for_pipeline(
         return processing_service_lowest_latency
 
     def process_images(self, images: typing.Iterable[SourceImage], project_id: int, job_id: int | None = None):
-        processing_service = self.choose_processing_service_for_pipeline(job_id, self.name, project_id)
+        processing_service = self.choose_processing_service_for_pipeline(job_id or 0, self.name, project_id)
 
         if not processing_service.endpoint_url:
             raise ValueError(
diff --git a/ami/ml/post_processing/class_masking.py b/ami/ml/post_processing/class_masking.py
new file mode 100644
index 000000000..81198d2c5
--- /dev/null
+++ b/ami/ml/post_processing/class_masking.py
@@ -0,0 +1,210 @@
+import logging
+
+from django.db.models import QuerySet
+from django.utils import timezone
+
+from ami.main.models import Classification, Occurrence, SourceImageCollection, TaxaList
+from ami.ml.models import Algorithm, AlgorithmCategoryMap
+
+logger = logging.getLogger(__name__)
+
+
+def update_single_occurrence(
+    occurrence: Occurrence,
+    algorithm: Algorithm,
+    taxa_list: TaxaList,
+    task_logger: logging.Logger = logger,
+):
+    task_logger.info(f"Recalculating classifications for occurrence {occurrence.pk}.")
+
+    # Get the terminal classifications for the occurrence from this algorithm
+    classifications = Classification.objects.filter(
+        detection__occurrence=occurrence,
+        terminal=True,
+        algorithm=algorithm,
+        scores__isnull=False,
+    ).distinct()
+
+    make_classifications_filtered_by_taxa_list(
+        classifications=classifications,
+        taxa_list=taxa_list,
+        algorithm=algorithm,
+    )
+
+
+def update_occurrences_in_collection(
+    collection: SourceImageCollection,
+    taxa_list: TaxaList,
+    algorithm: Algorithm,
+    params: dict,
+    task_logger: logging.Logger = logger,
+    job=None,
+):
+    task_logger.info(f"Recalculating classifications based on a taxa list. Params: {params}")
+
+    # Make new AlgorithmCategoryMap with the taxa in the list
+    # @TODO
+
+    classifications = Classification.objects.filter(
+        detection__source_image__collections=collection,
+        terminal=True,
+        # algorithm__task_type="classification",
+        algorithm=algorithm,
+        scores__isnull=False,
+    ).distinct()
+
+    make_classifications_filtered_by_taxa_list(
+        classifications=classifications,
+        taxa_list=taxa_list,
+        algorithm=algorithm,
+    )
+
+
+def make_classifications_filtered_by_taxa_list(
+    classifications: QuerySet[Classification],
+    taxa_list: TaxaList,
+    algorithm: Algorithm,
+):
+    taxa_in_list = taxa_list.taxa.all()
+
+    occurrences_to_update: set[Occurrence] = set()
+    logger.info(f"Found {len(classifications)} terminal classifications with scores to update.")
+
+    if not classifications:
+        raise ValueError("No terminal classifications with scores found to update.")
+
+    if not algorithm.category_map:
+        raise ValueError(f"Algorithm {algorithm} does not have a category map.")
+    category_map: AlgorithmCategoryMap = algorithm.category_map
+
+    # Consider moving this to a method on the Classification model
+
+    # @TODO find a more efficient way to get the category map with taxa. This is slow!
+    logger.info(f"Retrieving category map with Taxa instances for algorithm {algorithm}")
+    category_map_with_taxa = category_map.with_taxa()
+    # Filter the category map to only include taxa that are in the taxa list
+    # included_category_map_with_taxa = [
+    #     category for category in category_map_with_taxa if category["taxon"] in taxa_in_list
+    # ]
+    excluded_category_map_with_taxa = [
+        category for category in category_map_with_taxa if category["taxon"] not in taxa_in_list
+    ]
+
+    # included_category_indices = [int(category["index"]) for category in category_map_with_taxa]
+    excluded_category_indices = [
+        int(category["index"]) for category in excluded_category_map_with_taxa  # type: ignore
+    ]
+
+    # Log number of categories in the category map, num included, and num excluded, num classifications to update
+    logger.info(
+        f"Category map has {len(category_map_with_taxa)} categories, "
+        f"{len(excluded_category_map_with_taxa)} categories excluded, "
+        f"{len(classifications)} classifications to check"
+    )
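Editor's note: before the loop below rewrites scores, a standalone check of the approach it takes — masking excluded logits with a large negative value and re-applying softmax matches renormalizing the kept probabilities (illustrative only, not part of the patch):

import numpy as np

logits = np.array([2.0, 1.0, 0.5, -1.0])
excluded = [1, 3]  # category indexes not in the taxa list
kept = [0, 2]

masked = logits.copy()
masked[excluded] = -100  # effectively zero probability after softmax
scores = np.exp(masked - masked.max())
scores /= scores.sum()

original = np.exp(logits - logits.max())
original /= original.sum()
expected = np.zeros_like(original)
expected[kept] = original[kept] / original[kept].sum()

assert np.allclose(scores, expected, atol=1e-6)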
+
+    classifications_to_add = []
+    classifications_to_update = []
+
+    timestamp = timezone.now()
+    for classification in classifications:
+        scores, logits = classification.scores, classification.logits
+        # Mask scores and logits for categories that are not in the filtered category indices
+
+        import numpy as np
+
+        # Assert that all scores & logits are lists of numbers
+        if not isinstance(scores, list) or not all(isinstance(score, (int, float)) for score in scores):
+            raise ValueError(f"Scores for classification {classification.pk} are not a list of numbers: {scores}")
+        if not isinstance(logits, list) or not all(isinstance(logit, (int, float)) for logit in logits):
+            raise ValueError(f"Logits for classification {classification.pk} are not a list of numbers: {logits}")
+
+        logger.debug(f"Processing classification {classification.pk} with {len(scores)} scores")
+        logger.info(f"Previous totals: {sum(scores)} scores, {sum(logits)} logits")
+
+        # scores_np_filtered = np.array(scores)
+        logits_np = np.array(logits)
+
+        # scores_np_filtered[excluded_category_indices] = 0.0
+
+        # @TODO can we use np.NAN instead of 0.0? zero will NOT calculate correctly in softmax.
+        # @TODO delete the excluded categories from the scores and logits instead of setting to 0.0
+        # logits_np[excluded_category_indices] = 0.0
+        # logits_np[excluded_category_indices] = np.nan
+        logits_np[excluded_category_indices] = -100
+
+        logits: list[float] = logits_np.tolist()
+
+        from numpy import exp
+        from numpy import sum as np_sum
+
+        # @TODO add test to see if this is correct, or needed!
+        # Recalculate the softmax scores based on the filtered logits
+        scores_np: np.ndarray = exp(logits_np - np.max(logits_np))  # Subtract max for numerical stability
+        scores_np /= np_sum(scores_np)  # Normalize to get probabilities
+
+        scores: list = scores_np.tolist()  # Convert back to list
+
+        logger.info(f"New totals: {sum(scores)} scores, {sum(logits)} logits")
+
+        # Get the taxon with the highest score using the index of the max score
+        top_index = scores.index(max(scores))
+        top_taxon = category_map_with_taxa[top_index][
+            "taxon"
+        ]  # @TODO: This doesn't work if the taxon has never been classified
+        print("Top taxon: ", category_map_with_taxa[top_index])  # @TODO: REMOVE
+        print("Top index: ", top_index)  # @TODO: REMOVE
+
+        # check if needs updating
+        if classification.scores == scores and classification.logits == logits:
+            logger.debug(f"Classification {classification.pk} does not need updating")
+            continue
+
+        # Consider the existing classification as an intermediate classification
+        classification.terminal = False
+        classification.updated_at = timestamp
+
+        # Recalculate the top taxon and score
+        new_classification = Classification(
+            taxon=top_taxon,
+            algorithm=classification.algorithm,
+            score=max(scores),
+            scores=scores,
+            logits=logits,
+            detection=classification.detection,
+            timestamp=classification.timestamp,
+            terminal=True,
+            category_map=None,  # @TODO need a new category map with the filtered taxa
+            created_at=timestamp,
+            updated_at=timestamp,
+        )
+        if new_classification.taxon is None:
+            raise ValueError("Classification isn't registered yet. Aborting")  # @TODO remove or fail gracefully
+
+        classifications_to_update.append(classification)
+        classifications_to_add.append(new_classification)
+
+        assert new_classification.detection is not None
+        assert new_classification.detection.occurrence is not None
+        occurrences_to_update.add(new_classification.detection.occurrence)
+
+        logger.info(
+            f"Adding new classification for Taxon {top_taxon} to occurrence {new_classification.detection.occurrence}"
+        )
+
+    # Bulk update the existing classifications
+    if classifications_to_update:
+        logger.info(f"Bulk updating {len(classifications_to_update)} existing classifications")
+        Classification.objects.bulk_update(classifications_to_update, ["terminal", "updated_at"])
+        logger.info(f"Updated {len(classifications_to_update)} existing classifications")
+
+    if classifications_to_add:
+        # Bulk create the new classifications
+        logger.info(f"Bulk creating {len(classifications_to_add)} new classifications")
+        Classification.objects.bulk_create(classifications_to_add)
+        logger.info(f"Added {len(classifications_to_add)} new classifications")
+
+    # Update the occurrence determinations
+    logger.info(f"Updating the determinations for {len(occurrences_to_update)} occurrences")
+    for occurrence in occurrences_to_update:
+        occurrence.save(update_determination=True)
+    logger.info(f"Updated determinations for {len(occurrences_to_update)} occurrences")
diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py
index ce473480d..894dc3785 100644
--- a/ami/ml/schemas.py
+++ b/ami/ml/schemas.py
@@ -124,9 +124,16 @@ class SourceImageRequest(pydantic.BaseModel):
     # b64: str | None = None
 
 
+class DeploymentResponse(pydantic.BaseModel):
+    id: str | None = None
+    name: str
+    key: str | None = None
+
+
 class SourceImageResponse(pydantic.BaseModel):
     id: str
     url: str
+    deployment: DeploymentResponse | None = None
 
     class Config:
         extra = "ignore"
@@ -193,6 +200,7 @@ class PipelineResultsResponse(pydantic.BaseModel):
     total_time: float
     source_images: list[SourceImageResponse]
     detections: list[DetectionResponse]
+    deployments: list[DeploymentResponse] | None = None
     errors: list | str | None = None
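Editor's note: a minimal, illustrative fragment of a results payload exercising the new deployment fields (values are made up; a full PipelineResultsResponse still requires pipeline, total_time, detections and algorithms):

{
  "source_images": [
    {"id": "img-1", "url": "2023_01_24/55-20230124201159-snapshot.jpg", "deployment": {"name": "snapshots", "key": "snapshots"}}
  ],
  "deployments": [
    {"name": "snapshots", "key": "snapshots"}
  ]
}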