diff --git a/.github/workflows/test.backend.yml b/.github/workflows/test.backend.yml
index 28ca1cf53..e6d08e2c1 100644
--- a/.github/workflows/test.backend.yml
+++ b/.github/workflows/test.backend.yml
@@ -7,11 +7,11 @@ env:
on:
pull_request:
- branches: ["master", "main"]
+ branches: ["main", "deployments/*", "releases/*"]
paths-ignore: ["docs/**", "ui/**"]
push:
- branches: ["master", "main"]
+ branches: ["main", "deployments/*", "releases/*"]
paths-ignore: ["docs/**", "ui/**"]
concurrency:
diff --git a/.github/workflows/test.frontend.yml b/.github/workflows/test.frontend.yml
index 449a565f8..e93ed7b49 100644
--- a/.github/workflows/test.frontend.yml
+++ b/.github/workflows/test.frontend.yml
@@ -7,13 +7,13 @@ env:
on:
pull_request:
- branches: ["master", "main"]
+ branches: ["main", "deployments/*", "releases/*"]
paths:
- "!./**"
- "ui/**"
push:
- branches: ["master", "main"]
+ branches: ["main", "deployments/*", "releases/*"]
paths:
- "!./**"
- "ui/**"
diff --git a/ami/jobs/migrations/0017_alter_job_job_type_key.py b/ami/jobs/migrations/0017_alter_job_job_type_key.py
new file mode 100644
index 000000000..a1b74f46d
--- /dev/null
+++ b/ami/jobs/migrations/0017_alter_job_job_type_key.py
@@ -0,0 +1,29 @@
+# Generated by Django 4.2.10 on 2025-04-24 16:25
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("jobs", "0016_job_data_export_job_params_alter_job_job_type_key"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="job",
+ name="job_type_key",
+ field=models.CharField(
+ choices=[
+ ("ml", "ML pipeline"),
+ ("populate_captures_collection", "Populate captures collection"),
+ ("data_storage_sync", "Data storage sync"),
+ ("unknown", "Unknown"),
+ ("data_export", "Data Export"),
+ ("occurrence_clustering", "Occurrence Feature Clustering"),
+ ],
+ default="unknown",
+ max_length=255,
+ verbose_name="Job Type",
+ ),
+ ),
+ ]
diff --git a/ami/jobs/migrations/0018_alter_job_job_type_key.py b/ami/jobs/migrations/0018_alter_job_job_type_key.py
new file mode 100644
index 000000000..b1a4e664f
--- /dev/null
+++ b/ami/jobs/migrations/0018_alter_job_job_type_key.py
@@ -0,0 +1,29 @@
+# Generated by Django 4.2.10 on 2025-04-28 11:06
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("jobs", "0017_alter_job_job_type_key"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="job",
+ name="job_type_key",
+ field=models.CharField(
+ choices=[
+ ("ml", "ML pipeline"),
+ ("populate_captures_collection", "Populate captures collection"),
+ ("data_storage_sync", "Data storage sync"),
+ ("unknown", "Unknown"),
+ ("data_export", "Data Export"),
+ ("detection_clustering", "Detection Feature Clustering"),
+ ],
+ default="unknown",
+ max_length=255,
+ verbose_name="Job Type",
+ ),
+ ),
+ ]
diff --git a/ami/jobs/models.py b/ami/jobs/models.py
index 1166f2c74..189112661 100644
--- a/ami/jobs/models.py
+++ b/ami/jobs/models.py
@@ -400,7 +400,8 @@ def run(cls, job: "Job"):
total_classifications = 0
config = job.pipeline.get_config(project_id=job.project.pk)
- chunk_size = config.get("request_source_image_batch_size", 1)
+ chunk_size = config.get("request_source_image_batch_size", 2)
+ # @TODO Ensure only images of the same dimensions are processed in a batch
chunks = [images[i : i + chunk_size] for i in range(0, image_count, chunk_size)] # noqa
request_failed_images = []
@@ -639,6 +640,38 @@ def run(cls, job: "Job"):
job.update_status(JobState.SUCCESS, save=True)
+class DetectionClusteringJob(JobType):
+ name = "Detection Feature Clustering"
+ key = "detection_clustering"
+
+ @classmethod
+ def run(cls, job: "Job"):
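+ """
+ Cluster detection feature embeddings for the job's source image collection
+ and create placeholder "unknown species" taxa for each cluster.
+ """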
+ job.update_status(JobState.STARTED)
+ job.started_at = datetime.datetime.now()
+ job.finished_at = None
+ job.progress.add_stage(name="Collecting Features", key="feature_collection")
+ job.progress.add_stage("Clustering", key="clustering")
+ job.progress.add_stage("Creating Unknown Taxa", key="create_unknown_taxa")
+ job.save()
+
+ if not job.source_image_collection:
+ raise ValueError("No source image collection provided")
+
+ job.logger.info(f"Clustering detections for collection {job.source_image_collection}")
+
+ # Call the clustering method
+ job.source_image_collection.cluster_detections(job=job)
+ job.logger.info(f"Finished clustering detections for collection {job.source_image_collection}")
+
+ job.finished_at = datetime.datetime.now()
+ job.update_status(JobState.SUCCESS, save=False)
+ job.save()
+
+
class UnknownJobType(JobType):
name = "Unknown"
key = "unknown"
@@ -648,7 +681,14 @@ def run(cls, job: "Job"):
raise ValueError(f"Unknown job type '{job.job_type()}'")
-VALID_JOB_TYPES = [MLJob, SourceImageCollectionPopulateJob, DataStorageSyncJob, UnknownJobType, DataExportJob]
+VALID_JOB_TYPES = [
+ MLJob,
+ SourceImageCollectionPopulateJob,
+ DataStorageSyncJob,
+ UnknownJobType,
+ DataExportJob,
+ DetectionClusteringJob,
+]
def get_job_type_by_key(key: str) -> type[JobType] | None:
diff --git a/ami/main/admin.py b/ami/main/admin.py
index dcfa58241..a7b533b98 100644
--- a/ami/main/admin.py
+++ b/ami/main/admin.py
@@ -6,6 +6,7 @@
from django.http.request import HttpRequest
from django.template.defaultfilters import filesizeformat
from django.utils.formats import number_format
+from django.utils.html import format_html
from guardian.admin import GuardedModelAdmin
import ami.utils
@@ -220,7 +221,6 @@ def update_calculated_fields(self, request: HttpRequest, queryset: QuerySet[Even
self.message_user(request, f"Updated {queryset.count()} events.")
list_filter = ("deployment", "project", "start")
- actions = [update_calculated_fields]
@admin.register(SourceImage)
@@ -262,6 +262,7 @@ class ClassificationInline(admin.TabularInline):
model = Classification
extra = 0
fields = (
+ "view_classification",
"taxon",
"algorithm",
"timestamp",
@@ -269,6 +270,7 @@ class ClassificationInline(admin.TabularInline):
"created_at",
)
readonly_fields = (
+ "view_classification",
"taxon",
"algorithm",
"timestamp",
@@ -276,6 +278,11 @@ class ClassificationInline(admin.TabularInline):
"created_at",
)
+ @admin.display(description="Classification")
+ def view_classification(self, obj):
+ url = f"/admin/main/classification/{obj.pk}/change/"
+ return format_html('<a href="{}">{}</a>', url, obj.pk)
+
def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
return qs.select_related("taxon", "algorithm", "detection")
@@ -285,6 +292,7 @@ class DetectionInline(admin.TabularInline):
model = Detection
extra = 0
fields = (
+ "view_detection",
"detection_algorithm",
"source_image",
"timestamp",
@@ -292,6 +300,7 @@ class DetectionInline(admin.TabularInline):
"occurrence",
)
readonly_fields = (
+ "view_detection",
"detection_algorithm",
"source_image",
"timestamp",
@@ -299,6 +308,11 @@ class DetectionInline(admin.TabularInline):
"occurrence",
)
+ @admin.display(description="Detection")
+ def view_detection(self, obj):
+ url = f"/admin/main/detection/{obj.pk}/change/"
+ return format_html('<a href="{}">{}</a>', url, obj.pk)
+
@admin.register(Detection)
class DetectionAdmin(admin.ModelAdmin[Detection]):
@@ -461,7 +475,7 @@ class TaxonAdmin(admin.ModelAdmin[Taxon]):
"created_at",
"updated_at",
)
- list_filter = ("lists", "rank", TaxonParentFilter)
+ list_filter = ("unknown_species", "lists", "rank", TaxonParentFilter)
search_fields = ("name",)
autocomplete_fields = (
"parent",
@@ -594,7 +608,48 @@ def populate_collection_async(self, request: HttpRequest, queryset: QuerySet[Sou
f"Populating {len(queued_tasks)} collection(s) background tasks: {queued_tasks}.",
)
- actions = [populate_collection, populate_collection_async]
+ @admin.action(description="Create clustering job (but don't run it)")
+ def create_clustering_job(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
+ from ami.jobs.models import DetectionClusteringJob, Job
+
+ for collection in queryset:
+ job = Job.objects.create(
+ name=f"Clustering detections for collection {collection.pk}",
+ project=collection.project,
+ source_image_collection=collection,
+ job_type_key=DetectionClusteringJob.key,
+ params={
+ "ood_threshold": 0.3,
+ "algorithm": "agglomerative",
+ "algorithm_kwargs": {"distance_threshold": 80},
+ "pca": {"n_components": 384},
+ },
+ )
+ self.message_user(request, f"Created clustering job #{job.pk} for collection #{collection.pk}")
+
+ @admin.action(description="Cluster detections (create and run job)")
+ def cluster_detections(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
+ from ami.jobs.models import DetectionClusteringJob, Job
+
+ for collection in queryset:
+
+ job = Job.objects.create(
+ name=f"Clustering detections for collection {collection.pk}",
+ project=collection.project,
+ source_image_collection=collection,
+ job_type_key=DetectionClusteringJob.key,
+ params={
+ "ood_threshold": 0.3,
+ "algorithm": "agglomerative",
+ "algorithm_kwargs": {"distance_threshold": 80},
+ "pca": {"n_components": 384},
+ },
+ )
+ job.enqueue()
+
+ self.message_user(request, f"Clustered {queryset.count()} collection(s).")
+
+ actions = [populate_collection, populate_collection_async, cluster_detections, create_clustering_job]
# Hide images many-to-many field from form. This would list all source images in the database.
exclude = ("images",)
diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py
index 179d42ac0..6ff82cc09 100644
--- a/ami/main/api/serializers.py
+++ b/ami/main/api/serializers.py
@@ -518,6 +518,7 @@ class Meta:
"last_detected",
"best_determination_score",
"cover_image_url",
+ "unknown_species",
"created_at",
"updated_at",
]
@@ -740,6 +741,8 @@ class Meta:
"fieldguide_id",
"cover_image_url",
"cover_image_credit",
+ "unknown_species",
+ "last_detected", # @TODO this has performance impact, review
]
@@ -1548,3 +1551,11 @@ class Meta:
"total_size",
"last_checked",
]
+
+
+class ClusterDetectionsSerializer(serializers.Serializer):
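+ """Validate the optional clustering parameters accepted by the cluster_detections collection action."""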
+ ood_threshold = serializers.FloatField(required=False, default=0.0)
+ feature_extraction_algorithm = serializers.CharField(required=False, allow_null=True)
+ algorithm = serializers.CharField(required=False, default="agglomerative")
+ algorithm_kwargs = serializers.DictField(required=False, default={"distance_threshold": 0.5})
+ pca = serializers.DictField(required=False, default={"n_components": 384})
diff --git a/ami/main/api/views.py b/ami/main/api/views.py
index 2d0f6046d..88a06ffd0 100644
--- a/ami/main/api/views.py
+++ b/ami/main/api/views.py
@@ -42,6 +42,8 @@
)
from ami.base.serializers import FilterParamsSerializer, SingleParamSerializer
from ami.base.views import ProjectMixin
+from ami.jobs.models import DetectionClusteringJob, Job
+from ami.main.api.serializers import ClusterDetectionsSerializer
from ami.utils.requests import get_active_classification_threshold, project_id_doc_param
from ami.utils.storages import ConnectionTestResult
@@ -744,6 +746,27 @@ def remove(self, request, pk=None):
}
)
+ @action(detail=True, methods=["post"], name="cluster detections")
+ def cluster_detections(self, request, pk=None):
+ """
+ Trigger a background job to cluster detections from this collection.
+ """
+
+ collection: SourceImageCollection = self.get_object()
+ serializer = ClusterDetectionsSerializer(data=request.data)
+ serializer.is_valid(raise_exception=True)
+ params = serializer.validated_data
+ job = Job.objects.create(
+ name=f"Clustering detections for collection {collection.pk}",
+ project=collection.project,
+ source_image_collection=collection,
+ job_type_key=DetectionClusteringJob.key,
+ params=params,
+ )
+ job.enqueue()
+ logger.info(f"Triggered clustering job for collection {collection.pk}")
+ return Response({"job_id": job.pk, "project_id": collection.project.pk})
+
@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)
@@ -1273,8 +1296,7 @@ def get_queryset(self) -> QuerySet:
project = self.get_active_project()
if project:
- # Allow showing detail views for unobserved taxa
- include_unobserved = True
+ include_unobserved = True # Show detail views for unobserved taxa instead of 404
if self.action == "list":
include_unobserved = self.request.query_params.get("include_unobserved", False)
qs = self.get_taxa_observed(qs, project, include_unobserved=include_unobserved)
diff --git a/ami/main/migrations/0063_taxon_unknown_species.py b/ami/main/migrations/0063_taxon_unknown_species.py
new file mode 100644
index 000000000..2ba04d0b2
--- /dev/null
+++ b/ami/main/migrations/0063_taxon_unknown_species.py
@@ -0,0 +1,17 @@
+# Generated by Django 4.2.10 on 2025-04-28 11:11
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("main", "0062_classification_ood_score_and_more"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="taxon",
+ name="unknown_species",
+ field=models.BooleanField(default=False, help_text="Whether this taxon was auto-created by clustering"),
+ ),
+ ]
diff --git a/ami/main/migrations/0060_taxon_cover_image_credit_taxon_cover_image_url_and_more.py b/ami/main/migrations/0064_taxon_cover_image_credit_taxon_cover_image_url_and_more.py
similarity index 93%
rename from ami/main/migrations/0060_taxon_cover_image_credit_taxon_cover_image_url_and_more.py
rename to ami/main/migrations/0064_taxon_cover_image_credit_taxon_cover_image_url_and_more.py
index 4c74608f2..ca93c19e5 100644
--- a/ami/main/migrations/0060_taxon_cover_image_credit_taxon_cover_image_url_and_more.py
+++ b/ami/main/migrations/0064_taxon_cover_image_credit_taxon_cover_image_url_and_more.py
@@ -5,7 +5,7 @@
class Migration(migrations.Migration):
dependencies = [
- ("main", "0059_alter_project_options"),
+ ("main", "0063_taxon_unknown_species"),
]
operations = [
diff --git a/ami/main/models.py b/ami/main/models.py
index 150c2407e..b3f8bab49 100644
--- a/ami/main/models.py
+++ b/ami/main/models.py
@@ -30,6 +30,7 @@
from ami.base.fields import DateStringField
from ami.base.models import BaseModel
from ami.main import charts
+from ami.ml.clustering_algorithms.cluster_detections import cluster_detections
from ami.users.models import User
from ami.utils.schemas import OrderedEnum
@@ -2816,7 +2817,7 @@ class Taxon(BaseModel):
authorship_date = models.DateField(null=True, blank=True, help_text="The date the taxon was described.")
ordering = models.IntegerField(null=True, blank=True)
sort_phylogeny = models.BigIntegerField(blank=True, null=True)
-
+ unknown_species = models.BooleanField(default=False, help_text="Whether this taxon was auto-created by clustering")
objects: TaxonManager = TaxonManager()
# Type hints for auto-generated fields
@@ -3237,6 +3238,23 @@ def populate_sample(self, job: "Job | None" = None):
self.save()
task_logger.info(f"Done sampling and saving captures to {self}")
+ def cluster_detections(self, job: "Job | None" = None):
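+ """
+ Cluster the feature embeddings of detections in this collection.
+
+ Uses the parameters from the given job, or a default set of
+ agglomerative clustering parameters when called without a job.
+ """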
+ if job:
+ task_logger = job.logger
+ params = job.params
+ else:
+ task_logger = logger
+ params = {
+ "algorithm": "agglomerative",
+ "ood_threshold": 0.5,
+ "algorithm_kwargs": {
+ "distance_threshold": 0.5,
+ },
+ "pca": {"n_components": 384},
+ }
+
+ cluster_detections(collection=self, params=params, job=job, task_logger=task_logger)
+
def sample_random(self, size: int = 100):
"""Create a random sample of source images"""
diff --git a/ami/ml/clustering_algorithms/__init__.py b/ami/ml/clustering_algorithms/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/ami/ml/clustering_algorithms/agglomerative.py b/ami/ml/clustering_algorithms/agglomerative.py
new file mode 100644
index 000000000..23801b32b
--- /dev/null
+++ b/ami/ml/clustering_algorithms/agglomerative.py
@@ -0,0 +1,91 @@
+import logging
+import os
+
+import numpy as np
+from scipy.spatial.distance import pdist, squareform
+from sklearn.cluster import AgglomerativeClustering
+
+from .base_clusterer import BaseClusterer
+from .preprocessing_features import dimension_reduction, standardize
+
+logger = logging.getLogger(__name__)
+
+
+def get_distance_threshold(features, labels):
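+ """
+ Estimate a distance threshold from labelled data: the 95th percentile of
+ pairwise distances between samples that share the same label.
+ """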
+ distance_matrix = squareform(pdist(features))
+ intra_cluster_distances = []
+ inter_cluster_distances = []
+ for i in range(len(features)):
+ for j in range(i + 1, len(features)):
+ if labels[i] == labels[j]:
+ intra_cluster_distances.append(distance_matrix[i, j])
+ else:
+ inter_cluster_distances.append(distance_matrix[i, j])
+ # choose the 95th percentile of intra-cluster distances
+ threshold = np.percentile(intra_cluster_distances, 95)
+ return threshold
+
+
+class AgglomerativeClusterer(BaseClusterer):
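+ """Agglomerative clustering over (optionally PCA-reduced) feature vectors, configured by a plain dict."""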
+ def __init__(self, config: dict):
+ self.config = config
+ self.setup_flag = False
+ self.data_dict = None
+ # Access from dictionary instead of attribute
+ self.distance_threshold = config.get("algorithm_kwargs", {}).get("distance_threshold", 0.5)
+ self.n_components = config.get("pca", {}).get("n_components", 384)
+
+ def setup(self, data_dict):
+ # estimate the distance threshold
+ new_data_dict = {}
+ # Get output_dir from dictionary
+ save_dir = self.config.get("output_dir")
+
+ if not self.setup_flag:
+ for data_type in data_dict:
+ new_data_dict[data_type] = {}
+ features = data_dict[data_type]["feat_list"]
+ # Get n_components from dictionary
+ features = dimension_reduction(standardize(features), self.config.get("pca", {}).get("n_components"))
+ labels = data_dict[data_type]["label_list"]
+ new_data_dict[data_type]["feat_list"] = features
+ new_data_dict[data_type]["label_list"] = labels
+
+ np.savez(
+ os.path.join(
+ save_dir,
+ f"{data_type}_processed_pca_{self.config.get('pca', {}).get('n_components')}",
+ ),
+ feat_list=features,
+ label_list=labels,
+ )
+
+ self.data_dict = new_data_dict
+ self.setup_flag = True
+
+ # Auto-calculate threshold if not provided
+ if not self.distance_threshold:
+ self.distance_threshold = get_distance_threshold(
+ data_dict["val"]["feat_list"], data_dict["val"]["label_list"]
+ )
+
+ def cluster(self, features):
+ logger.info(f"distance threshold: {self.distance_threshold}")
+ logger.info("features shape: %s", features.shape)
+ logger.info(f"self.n_components: {self.n_components}")
+ # Get n_components and linkage from dictionary
+ if self.n_components <= min(features.shape[0], features.shape[1]):
+ features = dimension_reduction(standardize(features), self.n_components)
+ else:
+ features = standardize(features)
+ logger.info(f"Skipping PCA, n_components { self.n_components} is larger than features shape ")
+
+ # Get linkage parameter from config
+ linkage = self.config.get("algorithm_kwargs", {}).get("linkage", "ward")
+ logger.info(f" features shape after PCA: {features.shape}")
+
+ clusters = AgglomerativeClustering(
+ n_clusters=None, distance_threshold=self.distance_threshold, linkage=linkage
+ ).fit_predict(features)
+
+ return clusters
diff --git a/ami/ml/clustering_algorithms/base_clusterer.py b/ami/ml/clustering_algorithms/base_clusterer.py
new file mode 100644
index 000000000..586a50d82
--- /dev/null
+++ b/ami/ml/clustering_algorithms/base_clusterer.py
@@ -0,0 +1,43 @@
+import os
+
+import numpy as np
+
+from .preprocessing_features import dimension_reduction, standardize
+
+
+class BaseClusterer:
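+ """
+ Base class for clustering backends.
+
+ Note: this default setup() reads config values as attributes (e.g. config.output_dir),
+ while AgglomerativeClusterer overrides it to read from a plain dict.
+ """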
+ def __init__(self, config):
+ self.config = config
+ self.setup_flag = False
+ self.data_dict = None
+
+ def setup(self, data_dict):
+ new_data_dict = {}
+ save_dir = self.config.output_dir
+ if not self.setup_flag:
+ for data_type in data_dict:
+ new_data_dict[data_type] = {}
+ features = data_dict[data_type]["feat_list"]
+ features = dimension_reduction(standardize(features), self.config.pca.n_components)
+ labels = data_dict[data_type]["label_list"]
+ new_data_dict[data_type]["feat_list"] = features
+ new_data_dict[data_type]["label_list"] = labels
+
+ np.savez(
+ os.path.join(
+ save_dir,
+ f"{data_type}_processed_pca_{self.config.pca.n_components}",
+ ),
+ feat_list=features,
+ label_list=labels,
+ )
+ self.data_dict = new_data_dict
+ self.setup_flag = True
+ else:
+ pass
+
+ def clustering(self, data_dict):
+ pass
+
+ def cluster_detections(self, data_dict):
+ pass
diff --git a/ami/ml/clustering_algorithms/cluster_detections.py b/ami/ml/clustering_algorithms/cluster_detections.py
new file mode 100644
index 000000000..d06f8a194
--- /dev/null
+++ b/ami/ml/clustering_algorithms/cluster_detections.py
@@ -0,0 +1,183 @@
+import logging
+import typing
+
+import numpy as np
+from django.db.models import Count
+from django.utils.timezone import now
+
+from ami.ml.clustering_algorithms.utils import get_clusterer
+
+if typing.TYPE_CHECKING:
+ from ami.main.models import SourceImageCollection
+ from ami.ml.models import Algorithm
+
+logger = logging.getLogger(__name__)
+
+
+def update_job_progress(job, stage_key, status, progress):
+ if job:
+ job.progress.update_stage(stage_key, status=status, progress=progress)
+ job.save()
+
+
+def job_save(job):
+ if job:
+ job.save()
+
+
+def get_most_used_algorithm(
+ collection: "SourceImageCollection", task_logger: logging.Logger | None = None
+) -> "Algorithm | None":
+ from ami.main.models import Classification
+ from ami.ml.models import Algorithm
+
+ task_logger = task_logger or logger
+
+ qs = Classification.objects.filter(
+ features_2048__isnull=False,
+ detection__source_image__collections=collection,
+ algorithm__isnull=False,
+ # @TODO if we have a dedicated task type for feature extraction, we can filter by that
+ # task_type="feature_extraction",
+ )
+
+ # Log the number of classifications per algorithm, if debug is enabled
+ if task_logger.isEnabledFor(logging.DEBUG):
+ algorithm_stats = qs.values("algorithm__pk", "algorithm__name").annotate(count=Count("id")).order_by("-count")
+ task_logger.debug(f"Algorithm stats: {algorithm_stats}")
+
+ feature_extraction_algorithm_id = (
+ qs.values("algorithm")
+ .annotate(count=Count("id"))
+ .order_by("-count")
+ .values_list("algorithm", flat=True)
+ .first()
+ )
+ if feature_extraction_algorithm_id:
+ algorithm = Algorithm.objects.get(pk=feature_extraction_algorithm_id)
+ task_logger.info(f"Using feature extraction algorithm: {algorithm.name}")
+ return algorithm
+ return None
+
+
+def cluster_detections(collection, params: dict, task_logger: logging.Logger = logger, job=None):
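+ """
+ Cluster detections in a collection by their stored feature embeddings.
+
+ Collects features_2048 vectors for detections above the OOD threshold, runs the
+ configured clustering algorithm, creates an "unknown species" Taxon per cluster
+ plus a TaxaList, adds a Classification for each clustered detection, and updates
+ occurrences. Returns a mapping of cluster id to detections.
+ """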
+ from ami.jobs.models import JobState
+ from ami.main.models import Classification, Detection, TaxaList, Taxon
+ from ami.ml.models import Algorithm
+ from ami.ml.models.pipeline import create_and_update_occurrences_for_detections
+
+ ood_threshold = params.get("ood_threshold", 1)
+ feature_extraction_algorithm = params.get("feature_extraction_algorithm", None)
+ algorithm = params.get("clustering_algorithm", "agglomerative")
+ task_logger.info(f"Clustering Parameters: {params}")
+ job_save(job)
+ if feature_extraction_algorithm:
+ task_logger.info(f"Feature Extraction Algorithm: {feature_extraction_algorithm}")
+ # Check if the feature extraction algorithm is valid
+ if not Algorithm.objects.filter(key=feature_extraction_algorithm).exists():
+ raise ValueError(f"Invalid feature extraction algorithm key: {feature_extraction_algorithm}")
+ else:
+ # Fallback to the most used feature extraction algorithm in this collection
+ feature_extraction_algorithm = get_most_used_algorithm(collection, task_logger=task_logger)
+
+ detections = Detection.objects.filter(
+ classifications__features_2048__isnull=False,
+ classifications__algorithm=feature_extraction_algorithm,
+ source_image__collections=collection,
+ occurrence__determination_ood_score__gt=ood_threshold,
+ )
+
+ task_logger.info(f"Found {detections.count()} detections to process for clustering")
+
+ features = []
+ valid_detections = []
+ update_job_progress(job, stage_key="feature_collection", status=JobState.STARTED, progress=0.0)
+ # Collecting features for detections
+ for idx, detection in enumerate(detections):
+ classification = detection.classifications.filter(features_2048__isnull=False).first()
+ if classification:
+ features.append(classification.features_2048)
+ valid_detections.append(detection)
+ update_job_progress(
+ job,
+ stage_key="feature_collection",
+ status=JobState.STARTED,
+ progress=(idx + 1) / total_detections,
+ )
+ update_job_progress(job, stage_key="feature_collection", status=JobState.SUCCESS, progress=1.0)
+ logger.info(f"Clustering {len(features)} features from {len(valid_detections)} detections")
+
+ if not features:
+ raise ValueError("No feature vectors found")
+
+ features_np = np.array(features)
+ task_logger.info(f"Feature vectors shape: {features_np.shape}")
+ logger.info(f"First feature vector: {features_np[0]}, shape: {features_np[0].shape}")
+ update_job_progress(job, stage_key="clustering", status=JobState.STARTED, progress=0.0)
+ # Clustering Detections
+ ClusteringAlgorithm = get_clusterer(algorithm)
+ if not ClusteringAlgorithm:
+ raise ValueError(f"Unsupported clustering algorithm: {algorithm}")
+
+ cluster_ids = ClusteringAlgorithm(params).cluster(features_np)
+
+ task_logger.info(f"Clustering completed with {len(set(cluster_ids))} clusters")
+ clusters = {}
+ for idx, (cluster_id, detection) in enumerate(zip(cluster_ids, valid_detections)):
+ clusters.setdefault(cluster_id, []).append(detection)
+ update_job_progress(
+ job,
+ stage_key="clustering",
+ status=JobState.STARTED,
+ progress=(idx + 1) / len(valid_detections),
+ )
+ update_job_progress(job, stage_key="clustering", status=JobState.SUCCESS, progress=1.0)
+ taxa_list = TaxaList.objects.create(name=f"Clusters from Job {job.pk if job else 'unknown'}")
+ taxa_list.projects.add(collection.project)
+ taxa_to_add = []
+ clustering_algorithm, _created = Algorithm.objects.get_or_create(
+ name=ClusteringAlgorithm.__name__,
+ task_type="clustering",
+ )
+ logging.info(f"Using clustering algorithm: {clustering_algorithm}")
+ # Creating Unknown Taxa
+ update_job_progress(job, stage_key="create_unknown_taxa", status=JobState.STARTED, progress=0.0)
+ for idx, (cluster_id, detections_in_cluster) in enumerate(clusters.items()):
+ taxon, _created = Taxon.objects.get_or_create(
+ name=f"Cluster {cluster_id} (Collection {collection.pk}) (Job {job.pk if job else 'unknown'})",
+ rank="SPECIES",
+ notes=f"Auto-created cluster {cluster_id} for collection {collection.pk}",
+ unknown_species=True,
+ )
+ taxon.projects.add(collection.project)
+ taxa_to_add.append(taxon)
+
+ for detection in detections_in_cluster:
+ # Create a new Classification linking the detection to the new taxon
+
+ Classification.objects.create(
+ detection=detection,
+ taxon=taxon,
+ algorithm=clustering_algorithm,
+ score=1.0,
+ timestamp=now(),
+ logits=None,
+ features_2048=None,
+ scores=None,
+ terminal=True,
+ category_map=None,
+ )
+ update_job_progress(
+ job,
+ stage_key="create_unknown_taxa",
+ status=JobState.STARTED,
+ progress=(idx + 1) / len(clusters),
+ )
+ taxa_list.taxa.add(*taxa_to_add)
+ task_logger.info(f"Created {len(clusters)} clusters and updated {len(valid_detections)} detections")
+ update_job_progress(job, stage_key="create_unknown_taxa", status=JobState.SUCCESS, progress=1.0)
+
+ # Updating Occurrences
+ create_and_update_occurrences_for_detections(detections=valid_detections, logger=task_logger)
+ job_save(job)
+ return clusters
diff --git a/ami/ml/clustering_algorithms/clustering_metrics.py b/ami/ml/clustering_algorithms/clustering_metrics.py
new file mode 100644
index 000000000..41ce3da9f
--- /dev/null
+++ b/ami/ml/clustering_algorithms/clustering_metrics.py
@@ -0,0 +1,78 @@
+import numpy as np
+from sklearn.metrics import adjusted_mutual_info_score as ami_score
+from sklearn.metrics import adjusted_rand_score as ari_score
+from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
+
+from .estimate_k import cluster_acc
+
+
+def pairwise_cost(y_true, y_pred, split_cost=1, merge_cost=2):
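+ """
+ Count label-pair disagreements between y_true and y_pred, weighting
+ split errors (same true label, different cluster) by split_cost and
+ merge errors (different true label, same cluster) by merge_cost.
+ """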
+ true_match = y_true[:, None] == y_true
+ pred_match = y_pred[:, None] == y_pred
+
+ split = true_match & ~pred_match # true labels same, but cluster labels different
+ merge = ~true_match & pred_match # true labels different, but cluster labels same
+
+ # split and merge masks are mutually exclusive, so their weighted costs can be summed
+ cost = np.sum(np.triu(split * split_cost + merge * merge_cost, k=1))
+
+ return cost
+
+
+def get_clustering_metrics(labels, preds, old_mask, split_cost, merge_cost):
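+ """
+ Compute clustering accuracy, NMI, ARI, AMI and pairwise cost over all samples,
+ and separately for the "old" (old_mask) and "new" (~old_mask) subsets, after
+ mapping predicted clusters to ground-truth labels via cluster_acc's assignment.
+ """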
+ all_acc, ind, _ = cluster_acc(labels.astype(int), preds.astype(int), return_ind=True)
+
+ cluster_mapping = {pair[0]: pair[1] for pair in ind}
+
+ preds = np.array([cluster_mapping[c] for c in preds])
+
+ all_nmi, all_ari, all_ami = (
+ nmi_score(labels, preds),
+ ari_score(labels, preds),
+ ami_score(labels, preds),
+ )
+
+ all_pw_cost = pairwise_cost(labels, preds, split_cost, merge_cost)
+
+ old_preds = preds[old_mask]
+ new_preds = preds[~old_mask]
+
+ old_gt = labels[old_mask]
+ new_gt = labels[~old_mask]
+
+ old_acc, old_nmi, old_ari, old_ami = (
+ cluster_acc(old_gt.astype(int), old_preds.astype(int)),
+ nmi_score(old_gt, old_preds),
+ ari_score(old_gt, old_preds),
+ ami_score(old_gt, old_preds),
+ )
+
+ old_pw_cost = pairwise_cost(old_gt, old_preds, split_cost, merge_cost)
+
+ new_acc, new_nmi, new_ari, new_ami = (
+ cluster_acc(new_gt.astype(int), new_preds.astype(int)),
+ nmi_score(new_gt, new_preds),
+ ari_score(new_gt, new_preds),
+ ami_score(new_gt, new_preds),
+ )
+
+ new_pw_cost = pairwise_cost(new_gt, new_preds, split_cost, merge_cost)
+
+ metrics = {
+ "ACC_all": all_acc,
+ "NMI_all": all_nmi,
+ "ARI_all": all_ari,
+ "AMI_all": all_ami,
+ "pw_cost_all": all_pw_cost,
+ "ACC_old": old_acc,
+ "NMI_old": old_nmi,
+ "ARI_old": old_ari,
+ "AMI_old": old_ami,
+ "pw_cost_old": old_pw_cost,
+ "ACC_new": new_acc,
+ "NMI_new": new_nmi,
+ "ARI_new": new_ari,
+ "AMI_new": new_ami,
+ "pw_cost_new": new_pw_cost,
+ }
+
+ return metrics
diff --git a/ami/ml/clustering_algorithms/preprocessing_features.py b/ami/ml/clustering_algorithms/preprocessing_features.py
new file mode 100644
index 000000000..70fde61d7
--- /dev/null
+++ b/ami/ml/clustering_algorithms/preprocessing_features.py
@@ -0,0 +1,16 @@
+import logging
+
+from sklearn import preprocessing
+from sklearn.decomposition import PCA
+
+logger = logging.getLogger(__name__)
+
+
+def standardize(features):
+ """Scale features to zero mean and unit variance."""
+ scaler = preprocessing.StandardScaler().fit(features)
+ features = scaler.transform(features)
+ logger.debug("Standardized features")
+ return features
+
+
+def dimension_reduction(features, n_components):
+ """Reduce feature dimensionality with PCA."""
+ pca = PCA(n_components=n_components)
+ features = pca.fit_transform(features)
+ logger.debug("PCA performed, n_components=%s", n_components)
+ return features
diff --git a/ami/ml/clustering_algorithms/utils.py b/ami/ml/clustering_algorithms/utils.py
new file mode 100644
index 000000000..03af31a7e
--- /dev/null
+++ b/ami/ml/clustering_algorithms/utils.py
@@ -0,0 +1,16 @@
+from .agglomerative import AgglomerativeClusterer
+
+# from .dbscan import DBSCANClusterer
+# from .kmeans import KMeansClusterer
+# from .mean_shift import MeanShiftClusterer
+
+
+def get_clusterer(clustering_algorithm: str):
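+ """Return the clusterer class registered for the given algorithm key, or None if unsupported."""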
+ clusterers = {
+ # "kmeans": KMeansClusterer,
+ "agglomerative": AgglomerativeClusterer,
+ # "mean_shift": MeanShiftClusterer,
+ # "dbscan": DBSCANClusterer,
+ }
+
+ return clusterers.get(clustering_algorithm, None)
diff --git a/ami/ml/migrations/0023_alter_algorithm_task_type.py b/ami/ml/migrations/0023_alter_algorithm_task_type.py
new file mode 100644
index 000000000..2e8fa3975
--- /dev/null
+++ b/ami/ml/migrations/0023_alter_algorithm_task_type.py
@@ -0,0 +1,42 @@
+# Generated by Django 4.2.10 on 2025-05-06 18:10
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0022_alter_pipeline_default_config"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="algorithm",
+ name="task_type",
+ field=models.CharField(
+ choices=[
+ ("detection", "Detection"),
+ ("localization", "Localization"),
+ ("segmentation", "Segmentation"),
+ ("classification", "Classification"),
+ ("embedding", "Embedding"),
+ ("tracking", "Tracking"),
+ ("tagging", "Tagging"),
+ ("regression", "Regression"),
+ ("captioning", "Captioning"),
+ ("generation", "Generation"),
+ ("translation", "Translation"),
+ ("summarization", "Summarization"),
+ ("question_answering", "Question Answering"),
+ ("depth_estimation", "Depth Estimation"),
+ ("pose_estimation", "Pose Estimation"),
+ ("size_estimation", "Size Estimation"),
+ ("clustering", "Clustering"),
+ ("other", "Other"),
+ ("unknown", "Unknown"),
+ ],
+ default="unknown",
+ max_length=255,
+ null=True,
+ ),
+ ),
+ ]
diff --git a/ami/ml/models/algorithm.py b/ami/ml/models/algorithm.py
index 82753aac4..5a05d5352 100644
--- a/ami/ml/models/algorithm.py
+++ b/ami/ml/models/algorithm.py
@@ -137,6 +137,7 @@ class Algorithm(BaseModel):
("depth_estimation", "Depth Estimation"),
("pose_estimation", "Pose Estimation"),
("size_estimation", "Size Estimation"),
+ ("clustering", "Clustering"),
("other", "Other"),
("unknown", "Unknown"),
],
diff --git a/ami/ml/tests.py b/ami/ml/tests.py
index 30b32d1a0..eba76a1d4 100644
--- a/ami/ml/tests.py
+++ b/ami/ml/tests.py
@@ -7,7 +7,8 @@
from rest_framework.test import APIRequestFactory, APITestCase
from ami.base.serializers import reverse_with_params
-from ami.main.models import Classification, Detection, Project, SourceImage, SourceImageCollection
+from ami.main.models import Classification, Detection, Occurrence, Project, SourceImage, SourceImageCollection, Taxon
+from ami.ml.clustering_algorithms.cluster_detections import cluster_detections
from ami.ml.models import Algorithm, Pipeline, ProcessingService
from ami.ml.models.pipeline import collect_images, get_or_create_algorithm_and_category_map, save_results
from ami.ml.schemas import (
@@ -19,7 +20,15 @@
PipelineResultsResponse,
SourceImageResponse,
)
-from ami.tests.fixtures.main import create_captures_from_files, create_processing_service, setup_test_project
+from ami.tests.fixtures.main import (
+ create_captures,
+ create_captures_from_files,
+ create_detections,
+ create_processing_service,
+ create_taxa,
+ group_images_into_events,
+ setup_test_project,
+)
from ami.tests.fixtures.ml import ALGORITHM_CHOICES
from ami.users.models import User
@@ -685,3 +694,103 @@ def test_l2_distance(self):
most_similar = qs.first()
self.assertEqual(most_similar.pk, ref_cls.pk, "Most similar classification should be itself")
+
+
+class TestClustering(TestCase):
+ def setUp(self):
+ self.project, self.deployment = setup_test_project()
+ create_taxa(project=self.project)
+ create_captures(deployment=self.deployment, num_nights=2, images_per_night=10, interval_minutes=1)
+ group_images_into_events(deployment=self.deployment)
+
+ sample_size = 10
+ self.collection = SourceImageCollection.objects.create(
+ name="Test Random Source Image Collection",
+ project=self.project,
+ method="random",
+ kwargs={"size": sample_size},
+ )
+ self.collection.save()
+ self.collection.populate_sample()
+ assert self.collection.images.count() == sample_size
+ self.populate_collection_with_detections()
+ self.collection.save()
+ # create_occurrences(deployment=self.deployment)
+
+ self.detections = Detection.objects.filter(source_image__collections=self.collection)
+ self.assertGreater(len(self.detections), 0, "No detections found in the collection")
+ self._populate_detection_features()
+
+ def populate_collection_with_detections(self):
+ """Populate the collection with random detections."""
+ for image in self.collection.images.all():
+ # Create a random detection for each image
+ create_detections(
+ source_image=image, bboxes=[(0.0, 0.0, 1.0, 1.0), (0.1, 0.1, 0.9, 0.9), (0.2, 0.2, 0.8, 0.8)]
+ )
+
+ def _populate_detection_features(self):
+ """Populate detection features with random values."""
+ classifier = Algorithm.objects.get(key="random-species-classifier")
+ for detection in self.detections:
+ detection.associate_new_occurrence()
+ # Create a random feature vector
+ feature_vector = np.random.rand(2048).tolist()
+ # Assign the feature vector to the detection
+ classification = Classification.objects.create(
+ detection=detection,
+ algorithm=classifier,
+ taxon=None,
+ score=0.5,
+ ood_score=0.5,
+ features_2048=feature_vector,
+ timestamp=datetime.datetime.now(),
+ )
+ detection.classifications.add(classification) # type: ignore
+ assert classification.features_2048 is not None, "No features found for the detection"
+ assert detection.occurrence is not None, "No occurrence found for the detection"
+ detection.save()
+ # Call save once on all occurrences
+ for occurrence in Occurrence.objects.filter(detections__in=self.detections).distinct():
+ occurrence.save()
+
+ def test_agglomerative_clustering(self):
+ """Test agglomerative clustering with real implementation."""
+ # Call with agglomerative clustering parameters
+ params = {
+ "algorithm": "agglomerative",
+ "ood_threshold": 0.4,
+ "feature_extraction_algorithm": None, # None will select most used algorithm
+ "agglomerative": {"distance_threshold": 0.5, "linkage": "ward"},
+ "pca": {"n_components": 5}, # Use fewer components for test performance
+ }
+ # Execute the clustering function
+ clusters = cluster_detections(self.collection, params)
+
+ # The exact number could vary based on the random features and threshold
+ self.assertGreaterEqual(len(clusters), 1, "Should create at least 1 cluster")
+
+ # Verify all detections are assigned to clusters
+ total_detections = sum(len(detections) for detections in clusters.values())
+ self.assertEqual(
+ total_detections,
+ len(self.detections),
+ f"All {len(self.detections)} detections should be assigned to clusters",
+ )
+
+ # Check if detections with similar features are in the same cluster
+ # Create a map of detection to cluster_id
+ detection_to_cluster = {}
+ for cluster_id, detections_list in clusters.items():
+ for detection in detections_list:
+ detection_to_cluster[detection.id] = cluster_id
+
+ # Verify that each cluster has a corresponding taxon
+ taxa = Taxon.objects.filter(unknown_species=True)
+ self.assertEqual(taxa.count(), len(clusters), f"Should create {len(clusters)} taxa for the clusters")
+
+ # Verify that each taxon is associated with the project
+ for taxon in taxa:
+ self.assertIn(
+ self.project, taxon.projects.all(), f"Taxon {taxon.name} should be associated with the project"
+ )
diff --git a/ami/tests/fixtures/main.py b/ami/tests/fixtures/main.py
index 812153ff9..31a7a599d 100644
--- a/ami/tests/fixtures/main.py
+++ b/ami/tests/fixtures/main.py
@@ -202,7 +202,9 @@ def create_taxa(project: Project) -> TaxaList:
taxa_list.projects.add(project)
root, _created = Taxon.objects.get_or_create(name="Lepidoptera", rank=TaxonRank.ORDER.name)
root.projects.add(project)
- family_taxon, _ = Taxon.objects.get_or_create(name="Nymphalidae", parent=root, rank=TaxonRank.FAMILY.name)
+ family_taxon, _ = Taxon.objects.get_or_create(
+ name="Nymphalidae", defaults={"parent": root, "rank": TaxonRank.FAMILY.name}
+ )
family_taxon.projects.add(project)
genus_taxon, _ = Taxon.objects.get_or_create(name="Vanessa", parent=family_taxon, rank=TaxonRank.GENUS.name)
genus_taxon.projects.add(project)
diff --git a/docker-compose.yml b/docker-compose.yml
index 20e0462df..09ec5eba2 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -158,5 +158,4 @@ services:
networks:
antenna_network:
- external: true
name: antenna_network
diff --git a/requirements/base.txt b/requirements/base.txt
index 6257f4a2b..ebfbeaa32 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -97,3 +97,6 @@ pgvector
newrelic==9.6.0
gunicorn==20.1.0 # https://github.com/benoitc/gunicorn
# psycopg[c]==3.1.9 # https://github.com/psycopg/psycopg
+# ML
+scikit-learn
+scipy