diff --git a/docker-compose.override.yaml b/docker-compose.override.yaml index 4c60309ae..e1a325ad1 100644 --- a/docker-compose.override.yaml +++ b/docker-compose.override.yaml @@ -117,5 +117,72 @@ services: # Run the npm development script npm run dev " + + iqr_rest: + image: rdwatch-smqtk-iqr + build: + dockerfile: docker/smqtk-iqr.Dockerfile + context: . + # needed for dependencies to properly install + platform: linux/amd64 + command: + ["runApplication", "-a", "IqrService", "-c", "runApp.IqrRestService.json"] + ports: + - 5001:5001 + working_dir: /iqr + volumes: + # where rest + search app configs are stored + - ./iqr:/iqr + + # needed for accessing the workdir from the config + # EDIT THIS VOLUME MOUNT BEFORE RUNNING + #- /path/to/iqr-data/workdir:/iqr/workdir + - /Users/forrest.li/scratch/iqr/workdir:/iqr/workdir + + # needed to make sure faiss_index exists, otherwise IQR initialization won't work + # EDIT THIS VOLUME MOUNT BEFORE RUNNING + #- /path/to/iqr-data/models:/iqr/models + - /Users/forrest.li/scratch/iqr/models:/iqr/models + + # only needed for the built-in IQR interface + #smqtk_mongo: + # image: mongo:latest + # ports: + # - 27017:27017 + # volumes: + # - smqtk-mongo:/data/db + #iqr_web: + # image: rdwatch-smqtk-iqr + # # needed for dependencies to properly install + # platform: linux/amd64 + # command: + # [ + # "runApplication", + # "-a", + # "IqrSearchDispatcher", + # "-c", + # "runApp.IqrSearchApp.json", + # ] + # ports: + # - 5000:5000 + # working_dir: /iqr + # volumes: + # # where rest + search app configs are stored + # # - ./iqr/workdir:/iqr/workdir + # - ./iqr:/iqr + + # # needed for accessing the workdir from the config + # - /SMQTK-IQR/docs/tutorials/tutorial_003_geowatch_descriptors/workdir:/iqr/workdir + + # # needed to make sure faiss_index exists, otherwise IQR initialization won't work + # - /SMQTK-IQR/docs/tutorials/tutorial_003_geowatch_descriptors/models:/iqr/models + + # # The workdir/processed/chips/manifest.json file stores 
absolute paths, so we add another bind mount to satisfy the abs paths + # # from when I generated the outputs. + # - /SMQTK-IQR/docs/tutorials/tutorial_003_geowatch_descriptors/workdir:/SMQTK-IQR/docs/tutorials/tutorial_003_geowatch_descriptors/workdir + # depends_on: + # - smqtk_mongo + # - iqr_rest + volumes: celery-SAM-model: diff --git a/docker/smqtk-iqr.Dockerfile b/docker/smqtk-iqr.Dockerfile new file mode 100644 index 000000000..99ca825d2 --- /dev/null +++ b/docker/smqtk-iqr.Dockerfile @@ -0,0 +1,32 @@ +FROM python:3.11 + +WORKDIR / +RUN apt update && apt install -y git gdal-bin libgdal-dev +RUN git clone https://github.com/Erotemic/SMQTK-IQR.git +RUN git clone https://github.com/Kitware/SMQTK-Descriptors.git + +WORKDIR /SMQTK-IQR +RUN git checkout dev/add-tutorial-3 +RUN pip install -e . +RUN pip install faiss-cpu==1.8.0 \ + "psycopg2-binary>=2.9.5,<3.0.0" \ + scriptconfig \ + ubelt \ + rich \ + kwcoco \ + opencv-python-headless \ + girder-client \ + # this version matches the one from python:3.11 apt install gdal-bin + gdal==3.6.2 \ + geowatch \ + kwcoco \ + kwgis \ + kwutil \ + scriptconfig \ + # extra pkg for running `geowatch torch_model_stats ...` + netharn + +WORKDIR /SMQTK-Descriptors +RUN pip install -e . + +WORKDIR /SMQTK-IQR diff --git a/docs/IQR.md b/docs/IQR.md new file mode 100644 index 000000000..85dcec1bb --- /dev/null +++ b/docs/IQR.md @@ -0,0 +1,68 @@ +# Iterative Query Refinement + +## Getting Started Locally + +### Initial IQR Data + +To avoid generating the IQR mappings from scratch, you should have a `iqr-data.tar.gz` file with the following contents: + +``` +models/ + faiss_index_params.json + faiss_index +sites/ + *.geojson +workdir/ + data.memorySet.pickle + descriptor_set.pickle + idx2uid.mem_kvstore.pickle + uid2idx.mem_kvstore.pickle +``` + +Extract `iqr-data.tar.gz` to a suitable location. We will refer to it as `/path/to/iqr-data` from here on out. 
+ +### Docker Compose Volumes + +Edit the `iqr_rest` service in `docker-compose.override.yaml`. Specifically, the volumes must be updated with the following mounts. These are commented accordingly in the `docker-compose.override.yaml` file. + +**IMPORTANT**: Replace the `/path/to/iqr-data` path prefix with the correct path. + +- `/path/to/iqr-data/workdir:/iqr/workdir` +- `/path/to/iqr-data/models:/iqr/models` + +### Ingesting The Sites + +First, start the docker services and perform the requisite migrations and setup. + +```bash +docker compose up -d +docker compose run --rm django poetry run django-admin migrate +docker compose run --rm django poetry run django-admin createsuperuser +docker compose run --rm django poetry run django-admin loaddata lookups +``` + +Now, we can ingest the sites provided in the IQR data archive. The following snippet assumes a bash shell currently located in the RD-WATCH repo root. + +**IMPORTANT**: Replace the `/path/to/iqr-data` path prefix with the correct path. + +```bash +for region in "KR_R001" "KR_R002" "CH_R001" "NZ_R001" +do + python ./scripts/loadModelRun.py "$region" "/path/to/iqr-data/sites/${region}_*.geojson" --title "$region" --performer_shortcode TE +done +``` + +### Loading The WorldView Images + +To ensure that the IQR query results have an associated image, open the RD-WATCH interface in the browser and download the "WV" satellite chips for every model run. This may take a long time! + +## Running IQR through RD-WATCH + +1. Navigate to the RD-WATCH interface in the browser to enable IQR. <!-- TODO: restore the exact URL that was dropped from this step --> +1. Select a model run, and then select a site. If the site has IQR enabled, then there will be an IQR button (as shown below). Clicking this button will initiate an IQR query on that site, and a right sidebar will show up with the results. + +![IQR button](images/iqr-button.png) + +1. IQR refinement occurs in two steps: + 1. Update positive, neutral, and negative results in the IQR result listing. + 1. Run "Refine Query" to generate a new list of IQR results. 
diff --git a/docs/images/iqr-button.png b/docs/images/iqr-button.png new file mode 100644 index 000000000..39d8c2375 Binary files /dev/null and b/docs/images/iqr-button.png differ diff --git a/iqr/runApp.IqrRestService.json b/iqr/runApp.IqrRestService.json new file mode 100644 index 000000000..91afd6b16 --- /dev/null +++ b/iqr/runApp.IqrRestService.json @@ -0,0 +1,148 @@ +{ + "flask_app": { + "BASIC_AUTH_PASSWORD": "demo", + "BASIC_AUTH_USERNAME": "demo", + "SECRET_KEY": "MySuperUltraSecret", + "debug_server": true + }, + "server": { + "host": "0.0.0.0", + "port": 5001 + }, + "iqr_service": { + "plugin_notes": { + "classification_factory": "Selection of the backend in which classifications are stored. The in-memory version is recommended because normal caching mechanisms will not account for the variety of classifiers that can potentially be created via this utility.", + "classifier_config": "The configuration to use for training and using classifiers for the /classifier endpoint. When configuring a classifier for use, don't fill out model persistence values as many classifiers may be created and thrown away during this service's operation.", + "descriptor_factory": "What descriptor element factory to use when asked to compute a descriptor on data.", + "descriptor_generator": "Descriptor generation algorithm to use when requested to describe data.", + "descriptor_set": "This is the index from which given positive and negative example descriptors are retrieved from. Not used for nearest neighbor querying. This index must contain all descriptors that could possibly be used as positive/negative examples and updated accordingly.", + "neighbor_index": "This is the neighbor index to pull initial near-positive descriptors from.", + "relevancy_index_config": "The relevancy index config provided should not have persistent storage configured as it will be used in such a way that instances are created, built and destroyed often." 
+ }, + "plugins": { + "classification_factory": { + "smqtk_classifier.impls.classification_element.memory.MemoryClassificationElement": {}, + "type": "smqtk_classifier.impls.classification_element.memory.MemoryClassificationElement" + }, + "classifier_config": { + "smqtk_classifier.impls.classify_descriptor_supervised.sklearn_logistic_regression.SkLearnLogisticRegression": { + }, + "type": "smqtk_classifier.impls.classify_descriptor_supervised.sklearn_logistic_regression.SkLearnLogisticRegression" + }, + "descriptor_factory": { + "smqtk_descriptors.impls.descriptor_element.memory.DescriptorMemoryElement": {}, + "type": "smqtk_descriptors.impls.descriptor_element.memory.DescriptorMemoryElement" + }, + "descriptor_generator": { + "smqtk_descriptors.impls.descriptor_generator.prepopulated.PrePopulatedDescriptorGenerator": { + }, + "type": "smqtk_descriptors.impls.descriptor_generator.prepopulated.PrePopulatedDescriptorGenerator" + }, + "descriptor_set": { + "smqtk_descriptors.impls.descriptor_set.memory.MemoryDescriptorSet": { + "cache_element": { + "smqtk_dataprovider.impls.data_element.file.DataFileElement": { + "explicit_mimetype": null, + "filepath": "workdir/descriptor_set.pickle", + "readonly": false + }, + "type": "smqtk_dataprovider.impls.data_element.file.DataFileElement" + }, + "pickle_protocol": -1 + }, + "type": "smqtk_descriptors.impls.descriptor_set.memory.MemoryDescriptorSet" + }, + "neighbor_index": { + "smqtk_indexing.impls.nn_index.faiss.FaissNearestNeighborsIndex": { + "descriptor_set": { + "smqtk_descriptors.impls.descriptor_set.memory.MemoryDescriptorSet": { + "cache_element": { + "smqtk_dataprovider.impls.data_element.file.DataFileElement": { + "explicit_mimetype": null, + "filepath": "workdir/descriptor_set.pickle", + "readonly": false + }, + "type": "smqtk_dataprovider.impls.data_element.file.DataFileElement" + }, + "pickle_protocol": -1 + }, + "type": "smqtk_descriptors.impls.descriptor_set.memory.MemoryDescriptorSet" + }, + "factory_string": 
"IDMap,Flat", + "gpu_id": 0, + "idx2uid_kvs": { + "smqtk_dataprovider.impls.key_value_store.memory.MemoryKeyValueStore": { + "cache_element": { + "smqtk_dataprovider.impls.data_element.file.DataFileElement": { + "explicit_mimetype": null, + "filepath": "workdir/idx2uid.mem_kvstore.pickle", + "readonly": false + }, + "type": "smqtk_dataprovider.impls.data_element.file.DataFileElement" + } + }, + "type": "smqtk_dataprovider.impls.key_value_store.memory.MemoryKeyValueStore" + }, + "uid2idx_kvs": { + "smqtk_dataprovider.impls.key_value_store.memory.MemoryKeyValueStore": { + "cache_element": { + "smqtk_dataprovider.impls.data_element.file.DataFileElement": { + "explicit_mimetype": null, + "filepath": "workdir/uid2idx.mem_kvstore.pickle", + "readonly": false + }, + "type": "smqtk_dataprovider.impls.data_element.file.DataFileElement" + } + }, + "type": "smqtk_dataprovider.impls.key_value_store.memory.MemoryKeyValueStore" + }, + "index_element": { + "smqtk_dataprovider.impls.data_element.file.DataFileElement": { + "filepath": "models/faiss_index", + "readonly": false + }, + "type": "smqtk_dataprovider.impls.data_element.file.DataFileElement" + }, + "index_param_element": { + "smqtk_dataprovider.impls.data_element.file.DataFileElement": { + "filepath": "models/faiss_index_params.json", + "readonly": false + }, + "type": "smqtk_dataprovider.impls.data_element.file.DataFileElement" + }, + "ivf_nprobe": 64, + "metric_type": "l2", + "random_seed": 0, + "read_only": false, + "use_gpu": false + }, + "type": "smqtk_indexing.impls.nn_index.faiss.FaissNearestNeighborsIndex" + }, + "rank_relevancy_with_feedback": { + "smqtk_relevancy.impls.rank_relevancy.margin_sampling.RankRelevancyWithMarginSampledFeedback": { + "rank_relevancy": { + "smqtk_relevancy.impls.rank_relevancy.wrap_classifier.RankRelevancyWithSupervisedClassifier": { + "classifier_inst": { + "smqtk_classifier.impls.classify_descriptor_supervised.sklearn_logistic_regression.SkLearnLogisticRegression": { + }, + "type": 
"smqtk_classifier.impls.classify_descriptor_supervised.sklearn_logistic_regression.SkLearnLogisticRegression" + } + }, + "type": "smqtk_relevancy.impls.rank_relevancy.wrap_classifier.RankRelevancyWithSupervisedClassifier" + }, + "n": 10, + "center": 0.5 + }, + "type": "smqtk_relevancy.impls.rank_relevancy.margin_sampling.RankRelevancyWithMarginSampledFeedback" + } + }, + "session_control": { + "positive_seed_neighbors": 500, + "session_expiration": { + "check_interval_seconds": 30, + "enabled": true, + "session_timeout": 3600 + } + } + } +} diff --git a/iqr/runApp.IqrSearchApp.json b/iqr/runApp.IqrSearchApp.json new file mode 100644 index 000000000..3bcd97f50 --- /dev/null +++ b/iqr/runApp.IqrSearchApp.json @@ -0,0 +1,36 @@ +{ + "flask_app": { + "BASIC_AUTH_PASSWORD": "demo", + "BASIC_AUTH_USERNAME": "demo", + "SECRET_KEY": "MySuperUltraSecret", + "debug": true + }, + "iqr_tabs": { + "GEOWATCH_DEMO": { + "data_set": { + "smqtk_dataprovider.impls.data_set.memory.DataMemorySet": { + "cache_element": { + "smqtk_dataprovider.impls.data_element.file.DataFileElement": { + "explicit_mimetype": null, + "filepath": "workdir/data.memorySet.cache", + "readonly": false + }, + "type": "smqtk_dataprovider.impls.data_element.file.DataFileElement" + }, + "pickle_protocol": -1 + }, + "type": "smqtk_dataprovider.impls.data_set.memory.DataMemorySet" + }, + "iqr_service_url": "iqr_rest:5001", + "working_directory": "workdir" + } + }, + "mongo": { + "database": "smqtk", + "server": "smqtk_mongo:27017" + }, + "server": { + "host": "0.0.0.0", + "port": 5000 + } +} diff --git a/rdwatch/core/api.py b/rdwatch/core/api.py index f16eaaca3..57d6404e0 100644 --- a/rdwatch/core/api.py +++ b/rdwatch/core/api.py @@ -6,6 +6,7 @@ from .views import site from .views.animation import router as animation_router +from .views.iqr import router as iqr_router from .views.model_run import router as model_run_router from .views.performer import router as performer_router from .views.region import router as 
region_router @@ -27,6 +28,7 @@ api.add_router('/sites/', site.router) api.add_router('/satellite-fetching/', satellite_fetching_router) api.add_router('/animation/', animation_router) +api.add_router('/iqr/', iqr_router) # useful for getting information back about validation errors diff --git a/rdwatch/core/migrations/0042_siteevaluation_smqtk_uuid.py b/rdwatch/core/migrations/0042_siteevaluation_smqtk_uuid.py new file mode 100644 index 000000000..ef6eb4a3d --- /dev/null +++ b/rdwatch/core/migrations/0042_siteevaluation_smqtk_uuid.py @@ -0,0 +1,19 @@ +# Generated by Django 5.0.9 on 2024-10-25 11:20 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ('core', '0041_merge_20241021_1458'), + ] + + operations = [ + migrations.AddField( + model_name='siteevaluation', + name='smqtk_uuid', + field=models.CharField( + blank=True, help_text='SMQTK UUID', max_length=256, null=True + ), + ), + ] diff --git a/rdwatch/core/models/site_evaluation.py b/rdwatch/core/models/site_evaluation.py index e78272800..f93e35b6f 100644 --- a/rdwatch/core/models/site_evaluation.py +++ b/rdwatch/core/models/site_evaluation.py @@ -103,6 +103,13 @@ class Status(models.TextChoices): help_text='Hash of the file for proposals', ) + smqtk_uuid = models.CharField( + max_length=256, + blank=True, + null=True, + help_text='SMQTK UUID', + ) + @property def boundingbox(self) -> tuple[float, float, float, float]: if self.geom: @@ -151,6 +158,9 @@ def bulk_create_from_site_model( point = None geom = None + smqtk_uuid = None + if site_feature.properties.cache: + smqtk_uuid = site_feature.properties.cache.smqtk_uuid if isinstance(site_feature.parsed_geometry, Point): point = site_feature.parsed_geometry else: @@ -170,6 +180,7 @@ def bulk_create_from_site_model( cache_originator_file=cache_originator_file, cache_timestamp=cache_timestamp, cache_commit_hash=cache_commit_hash, + smqtk_uuid=smqtk_uuid, modified_timestamp=datetime.now(), ) 
SiteObservation.bulk_create_from_site_evaluation(site_eval, site_model) diff --git a/rdwatch/core/schemas/region_model.py b/rdwatch/core/schemas/region_model.py index 1f9fe9f26..e61ce54dc 100644 --- a/rdwatch/core/schemas/region_model.py +++ b/rdwatch/core/schemas/region_model.py @@ -19,6 +19,7 @@ class RegionFeature(Schema): start_date: datetime | None end_date: datetime | None originator: str + iqr_enabled: bool | None # Optional fields comments: str | None diff --git a/rdwatch/core/schemas/site_model.py b/rdwatch/core/schemas/site_model.py index aed30003d..7d41f32aa 100644 --- a/rdwatch/core/schemas/site_model.py +++ b/rdwatch/core/schemas/site_model.py @@ -23,6 +23,9 @@ class SiteFeatureCache(Schema): originator_file: str | None timestamp: datetime | None commit_hash: str | None + # smqtk + video_name: str | None + smqtk_uuid: str | None class SiteFeature(Schema): diff --git a/rdwatch/core/tasks/__init__.py b/rdwatch/core/tasks/__init__.py index 7e91963be..435bedb10 100644 --- a/rdwatch/core/tasks/__init__.py +++ b/rdwatch/core/tasks/__init__.py @@ -967,8 +967,15 @@ def process_model_run_upload(model_run_upload: ModelRunUpload): public=not model_run_upload.private, ) + iqr_enabled = False for site_model in site_models: - SiteEvaluation.bulk_create_from_site_model(site_model, model_run) + site_eval = SiteEvaluation.bulk_create_from_site_model( + site_model, model_run + ) + iqr_enabled = iqr_enabled or site_eval.smqtk_uuid is not None + + if iqr_enabled: + region_model.region_feature.properties.iqr_enabled = iqr_enabled if region_model: SiteEvaluation.bulk_create_from_region_model(region_model, model_run) diff --git a/rdwatch/core/views/iqr.py b/rdwatch/core/views/iqr.py new file mode 100644 index 000000000..778f24f74 --- /dev/null +++ b/rdwatch/core/views/iqr.py @@ -0,0 +1,543 @@ +import bisect +import json +import logging +from collections import defaultdict +from typing import Literal + +import requests +from ninja import Router, Schema + +from 
django.contrib.gis.db.models.functions import Area, Transform +from django.core.files.storage import default_storage +from django.db import connection +from django.db.models import ( + BooleanField, + Case, + Count, + F, + Field, + Func, + Min, + Q, + When, + Window, +) +from django.http import HttpRequest, HttpResponse +from django.shortcuts import get_object_or_404 + +from rdwatch.core.db.functions import ExtractEpoch, GroupExcludeRowRange +from rdwatch.core.models import SiteEvaluation, SiteImage, SiteObservation + +logger = logging.getLogger(__name__) +router = Router() +session = requests.Session() + +BASE = 'http://iqr_rest:5001' +MAX_RESULTS = 50 + + +class SuccessResponse(Schema): + success: bool + + +class IQRInitializeRequest(Schema): + sid: str | None + init_pos_uuid: str + + +class IQRInitializeResponse(SuccessResponse): + sid: str | None + + +class IQRSessionInfo(Schema): + sid: str + time: dict + uuids_neg: list[str] + uuids_neg_ext: list[str] + uuids_neg_ext_in_model: list[str] + uuids_neg_in_model: list[str] + uuids_pos: list[str] + uuids_pos_ext: list[str] + uuids_pos_ext_in_model: list[str] + uuids_pos_in_model: list[str] + wi_count: int + + +class IQROrderedResultItem(Schema): + pk: str + site_uid: str + site_id: str + image_url: str | None + image_bbox: tuple[float, float, float, float] | None + smqtk_uuid: str + confidence: float + geom: str + geom_extent: list[float] + + +class IQROrderedResults(Schema): + sid: str + # i: int + # j: int + total_results: int + results: list[IQROrderedResultItem] + # results: list[tuple[str, float]] + + +class IQRAdjudicationEntry(Schema): + uuid: str + status: Literal['positive', 'neutral', 'negative'] + + +class IQRAdjudicationRequest(Schema): + adjudications: list[IQRAdjudicationEntry] + + +@router.post('/initialize', response={200: IQRInitializeResponse, 400: SuccessResponse}) +def initialize(request: HttpRequest, init_data: IQRInitializeRequest): + sid = init_data.sid + if not sid: + resp = 
session.post(f'{BASE}/session') + if resp.status_code != 201: + logger.error('Could not create session (code: %d)', resp.status_code) + return 400, {'success': False} + sid = resp.json()['sid'] + + resp = session.put(f'{BASE}/session', data={'sid': sid}) + if resp.status_code != 200: + logger.error('Could not initialize session (code: %d)', resp.status_code) + return 400, {'success': False} + + resp = session.post( + f'{BASE}/adjudicate', + data={ + 'sid': sid, + 'pos': json.dumps([init_data.init_pos_uuid]), + 'neg': json.dumps([]), + 'neutral': json.dumps([]), + }, + ) + if resp.status_code != 200: + logger.error('Could not adjudicate (code: %d)', resp.status_code) + return 400, {'success': False} + + resp = session.post(f'{BASE}/initialize', data={'sid': sid}) + resp.raise_for_status() + data = resp.json() + return 200, { + 'sid': sid, + 'success': data.get('success', False), + } + + +@router.post('/{sid}/refine', response={200: SuccessResponse, 400: SuccessResponse}) +def refine(request: HttpRequest, sid: str): + resp = session.post(f'{BASE}/refine', data={'sid': sid}) + if resp.status_code != 201: + logger.error('Could not refine session (code: %d)', resp.status_code) + return 400, {'success': False} + return 200, {'success': True} + + +@router.get('/{sid}', response={200: IQRSessionInfo, 400: SuccessResponse}) +def get_session_info(request: HttpRequest, sid: str): + resp = session.get(f'{BASE}/session', params={'sid': sid}) + if resp.status_code != 200: + logger.error('Could not get session info (code: %d)', resp.status_code) + return 400, {'success': False} + return 200, resp.json() + + +def pick_site_image( + images: list[SiteImage], observations: list[SiteObservation] +) -> SiteImage | None: + # ignore observations with no timestamps + observations = [o for o in observations if o.timestamp is not None] + + if not len(images): + return None + if not len(observations): + return images[-1] + + # pick either last active_construction observation, or first 
post_construction observation + # if no active_construction or post_construction, pick last observation + obs_candidate = observations[-1] + for obs in observations[::-1]: + if obs.label.slug == 'active_construction': + obs_candidate = obs + break + if obs.label.slug == 'post_construction': + obs_candidate = obs + + # find image closest to obs_candidate.timestamp + idx = bisect.bisect_left(images, obs_candidate.timestamp, key=lambda v: v.timestamp) + if idx == len(images): + return images[-1] + return images[idx] + + +@router.get('/{sid}/results', response={200: IQROrderedResults, 400: SuccessResponse}) +def get_ordered_results(request: HttpRequest, sid: str): + resp = session.get(f'{BASE}/get_results', params={'sid': sid}) + if resp.status_code != 200: + logger.error('Could not get session info (code: %d)', resp.status_code) + return 400, {'success': False} + resp_results = resp.json() + uuids: list[str] = [] + confidence_by_uuid = {} + for smqtk_uuid, confidence in resp_results['results']: + uuids.append(smqtk_uuid) + confidence_by_uuid[smqtk_uuid] = confidence + + site_evals = SiteEvaluation.objects.filter(smqtk_uuid__in=uuids) + site_evals = sorted( + [site for site in site_evals], + key=lambda site: -confidence_by_uuid[site.smqtk_uuid], + )[:MAX_RESULTS] + site_ids = [site.id for site in site_evals] + + images_by_site = defaultdict(list) + for site_image in SiteImage.objects.filter(site__in=site_ids).order_by('timestamp'): + images_by_site[site_image.site.id].append(site_image) + + observations_by_site = defaultdict(list) + for obs in SiteObservation.objects.filter(siteeval__in=site_ids).order_by( + 'timestamp' + ): + observations_by_site[obs.siteeval.id].append(obs) + + ordered_results = { + 'sid': sid, + 'total_results': resp_results['total_results'], + 'results': [], + } + for site in site_evals: + site_image = pick_site_image( + images_by_site[site.id], observations_by_site[site.id] + ) + image_url = default_storage.url(site_image.image.name) if site_image 
else None + image_bbox = ( + site_image.image_bbox.extent + if site_image and site_image.image_bbox + else None + ) + ordered_results['results'].append( + { + 'pk': str(site.id), + 'site_uid': str(site.id), + 'site_id': str(site.site_id), + 'image_url': image_url, + 'image_bbox': image_bbox, + 'smqtk_uuid': site.smqtk_uuid, + 'confidence': confidence_by_uuid[site.smqtk_uuid], + 'geom': str(site.geom), + 'geom_extent': site.geom.transform(4326, clone=True).extent, + } + ) + + ordered_results['results'] = sorted( + ordered_results['results'], key=lambda r: -r['confidence'] + ) + return 200, ordered_results + + +@router.post('/{sid}/adjudicate', response={200: SuccessResponse, 400: SuccessResponse}) +def adjudicate(request: HttpRequest, sid: str, adjudications: IQRAdjudicationRequest): + positives: list[str] = [] + neturals: list[str] = [] + negatives: list[str] = [] + + for entry in adjudications.adjudications: + if entry.status == 'positive': + positives.append(entry.uuid) + elif entry.status == 'neutral': + neturals.append(entry.uuid) + elif entry.status == 'negative': + negatives.append(entry.uuid) + + resp = session.post( + f'{BASE}/adjudicate', + data={ + 'sid': sid, + 'pos': json.dumps(positives), + 'neg': json.dumps(negatives), + 'neutral': json.dumps(neturals), + }, + ) + if resp.status_code != 200: + logger.error('Could not adjudicate (code: %d)', resp.status_code) + return 400, {'success': False} + return 200, {'success': True} + + +@router.get('/site-image-url/{site_id}') +def get_site_image_url(request: HttpRequest, site_id: str): + site = get_object_or_404(SiteEvaluation, id=site_id) + observations = list( + SiteObservation.objects.filter(siteeval=site).order_by('timestamp') + ) + images = list( + SiteImage.objects.filter(site=site, source='WV').order_by('timestamp') + ) + site_image = pick_site_image(images, observations) + return default_storage.url(site_image.image.name) if site_image else None + + +@router.get('/{sid}/vector-tile/{z}/{x}/{y}.pbf/') 
+def iqr_vector_tile(request: HttpRequest, sid: str, z: int, x: int, y: int): + resp = session.get(f'{BASE}/get_results', params={'sid': sid}) + if resp.status_code != 200: + logger.error('Could not get session info (code: %d)', resp.status_code) + return HttpResponse('Bad Request', status=400) + resp_results = resp.json() + uuids: list[str] = [] + for smqtk_uuid, _ in resp_results['results'][:MAX_RESULTS]: + uuids.append(smqtk_uuid) + + site_evals = SiteEvaluation.objects.filter(smqtk_uuid__in=uuids) + site_ids = [site.id for site in site_evals] + + envelope = Func(z, x, y, function='ST_TileEnvelope') + intersects_geom = Q( + Func( + 'geom', + envelope, + function='ST_Intersects', + output_field=BooleanField(), + ) + ) + intersects_point = Q( + Func( + 'point', + envelope, + function='ST_Intersects', + output_field=BooleanField(), + ) + ) + intersects = intersects_point | intersects_geom + mvtgeom_point = Func( + 'point', + envelope, + function='ST_AsMVTGeom', + output_field=Field(), + ) + mvtgeom = Func( + 'geom', + envelope, + function='ST_AsMVTGeom', + output_field=Field(), + ) + + evaluations_queryset = ( + SiteEvaluation.objects.filter(id__in=site_ids) + .filter(intersects_geom) + .values() + .alias(observation_count=Count('observations')) + .annotate( + id=F('pk'), + uuid=F('pk'), # maintain consistency with scoring DB for clicking on items + mvtgeom=mvtgeom, + configuration_id=F('configuration_id'), + configuration_name=F('configuration__title'), + label=F('label__slug'), + timestamp=ExtractEpoch('timestamp'), + timemin=ExtractEpoch('start_date'), + timemax=ExtractEpoch('end_date'), + performer_id=F('configuration__performer_id'), + performer_name=F('configuration__performer__short_code'), + region=F('configuration__region__name'), + groundtruth=Case( + When( + Q(configuration__performer__short_code='TE') & Q(score=1), + True, + ), + default=False, + ), + site_number=F('number'), + site_polygon=Case( + When( + observation_count=0, + then=True, + ), + 
default=False, + ), + ) + ) + ( + evaluations_sql, + evaluations_params, + ) = evaluations_queryset.query.sql_with_params() + + evaluations_points_queryset = ( + SiteEvaluation.objects.filter(id__in=site_ids) + .filter(intersects_point) + .values() + .alias(observation_count=Count('observations')) + .annotate( + id=F('pk'), + uuid=F('pk'), # maintain consistency with scoring DB for clicking on items + mvtgeom=mvtgeom_point, + configuration_id=F('configuration_id'), + configuration_name=F('configuration__title'), + label=F('label__slug'), + timestamp=ExtractEpoch('timestamp'), + timemin=ExtractEpoch('start_date'), + timemax=ExtractEpoch('end_date'), + performer_id=F('configuration__performer_id'), + performer_name=F('configuration__performer__short_code'), + region=F('configuration__region__name'), + groundtruth=Case( + When( + Q(configuration__performer__short_code='TE') & Q(score=1), + True, + ), + default=False, + ), + site_number=F('number'), + site_polygon=Case( + When( + observation_count=0, + then=True, + ), + default=False, + ), + ) + ) + ( + evaluations_points_sql, + evaluations_points_params, + ) = evaluations_points_queryset.query.sql_with_params() + + observations_queryset = ( + SiteObservation.objects.filter(siteeval__in=site_ids) + .filter(intersects) + .values() + .annotate( + id=F('pk'), + mvtgeom=mvtgeom, + configuration_id=F('siteeval__configuration_id'), + configuration_name=F('siteeval__configuration__title'), + site_label=F('siteeval__label__slug'), + site_number=F('siteeval__number'), + label=F('label__slug'), + area=Area(Transform('geom', srid=6933)), + timemin=ExtractEpoch('timestamp'), + timemax=ExtractEpoch( + Window( + expression=Min('timestamp'), + partition_by=[F('siteeval')], + frame=GroupExcludeRowRange(start=0, end=None), + order_by='timestamp', # type: ignore + ), + ), + performer_id=F('siteeval__configuration__performer_id'), + performer_name=F('siteeval__configuration__performer__short_code'), + 
region=F('siteeval__configuration__region__name'), + version=F('siteeval__version'), + groundtruth=Case( + When( + Q(siteeval__configuration__performer__short_code='TE') + & Q(siteeval__score=1), + True, + ), + default=False, + ), + ) + ) + ( + observations_sql, + observations_params, + ) = observations_queryset.query.sql_with_params() + + observations_points_queryset = ( + SiteObservation.objects.filter(siteeval__in=site_ids) + .filter(intersects_point) + .values() + .annotate( + id=F('pk'), + mvtgeom=mvtgeom_point, + configuration_id=F('siteeval__configuration_id'), + configuration_name=F('siteeval__configuration__title'), + site_label=F('siteeval__label__slug'), + site_number=F('siteeval__number'), + label=F('label__slug'), + area=Area(Transform('geom', srid=6933)), + timemin=ExtractEpoch('timestamp'), + timemax=ExtractEpoch( + Window( + expression=Min('timestamp'), + partition_by=[F('siteeval')], + frame=GroupExcludeRowRange(start=0, end=None), + order_by='timestamp', # type: ignore + ), + ), + performer_id=F('siteeval__configuration__performer_id'), + performer_name=F('siteeval__configuration__performer__short_code'), + region=F('siteeval__configuration__region__name'), + version=F('siteeval__version'), + groundtruth=Case( + When( + Q(siteeval__configuration__performer__short_code='TE') + & Q(siteeval__score=1), + True, + ), + default=False, + ), + ) + ) + ( + observations_points_sql, + observations_points_params, + ) = observations_points_queryset.query.sql_with_params() + + sql = f""" + WITH + evaluations AS ({evaluations_sql}), + observations AS ({observations_sql}), + evaluations_points AS ({evaluations_points_sql}), + observations_points AS ({observations_points_sql}) + SELECT ( + ( + SELECT ST_AsMVT(evaluations.*, %s, 4096, 'mvtgeom') + FROM evaluations + ) + || + ( + SELECT ST_AsMVT(observations.*, %s, 4096, 'mvtgeom') + FROM observations + ) + || + ( + SELECT ST_AsMVT(evaluations_points.*, %s, 4096, 'mvtgeom') + FROM evaluations_points + ) + || + ( + 
SELECT ST_AsMVT(observations_points.*, %s, 4096, 'mvtgeom') + FROM observations_points + ) + ) + """ # noqa: E501 + params = ( + evaluations_params + + observations_params + + evaluations_points_params + + observations_points_params + + ( + f'sites-{sid}', + f'observations-{sid}', + f'sites_points-{sid}', + f'observations_points-{sid}', + ) + ) + + with connection.cursor() as cursor: + cursor.execute(sql, params) + row = cursor.fetchone() + tile = row[0] + + return HttpResponse( + tile, + content_type='application/octet-stream', + status=200 if tile else 204, + ) diff --git a/rdwatch/core/views/model_run.py b/rdwatch/core/views/model_run.py index 6100ace51..29ce6a557 100644 --- a/rdwatch/core/views/model_run.py +++ b/rdwatch/core/views/model_run.py @@ -579,6 +579,7 @@ def get_sites_query(model_run_id: UUID4): filename='cache_originator_file', downloading='downloading', groundtruth=F('configuration__ground_truth'), + smqtk_uuid='smqtk_uuid', ), ordering='number', default=[], diff --git a/vue/src/actions/map.ts b/vue/src/actions/map.ts index 84a8c9bac..a19c4a0a8 100644 --- a/vue/src/actions/map.ts +++ b/vue/src/actions/map.ts @@ -1,3 +1,4 @@ +import { type FitBoundsOptions } from 'maplibre-gl'; import { type BoundingBox, createEventHook } from '../utils'; -export const FitBoundsEvent = createEventHook(); +export const FitBoundsEvent = createEventHook(); diff --git a/vue/src/client/services/ApiService.ts b/vue/src/client/services/ApiService.ts index 1aa323f3f..47ccd7cbf 100644 --- a/vue/src/client/services/ApiService.ts +++ b/vue/src/client/services/ApiService.ts @@ -68,6 +68,7 @@ export interface SiteInfo { downloading: boolean; groundtruth?: boolean; originator?: string; + smqtk_uuid?: string | null; } export interface SiteList { region: Region; @@ -250,6 +251,45 @@ export interface SatelliteFetchingDownloadingInfo { siteEvalId: string; } +export interface IQRInitializeResponse { + sid: string; + success: boolean; +} + +export interface IQRSessionInfo { + success: 
boolean; + sid: string; + time: object; + uuids_neg: string[]; + uuids_neg_ext: string[]; + uuids_neg_ext_in_model: string[]; + uuids_neg_in_model: string[]; + uuids_pos: string[]; + uuids_pos_ext: string[]; + uuids_pos_ext_in_model: string[]; + uuids_pos_in_model: string[]; + wi_count: number; +} + +export interface IQROrderedResultItem { + pk: string; + site_uid: string; + site_id: string; + image_url: string | null; + image_bbox: [number, number, number, number] | null; + smqtk_uuid: string; + confidence: number; + geom: string; + geom_extent: number[]; +} + +export interface IQROrderedResults { + i: number; + j: number; + sid: string; + total_results: number; + results: Array; +} type ApiPrefix = '/api' | '/api/scoring'; @@ -528,6 +568,68 @@ export class ApiService { }); } + public static iqrInitialize(sessionId: string | null, posUuid: string): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: `${this.getApiPrefix()}/iqr/initialize`, + body: { + sid: sessionId, + init_pos_uuid: posUuid, + }, + }); + } + + public static iqrRefine(sessionId: string): CancelablePromise<{ success: boolean }> { + return __request(OpenAPI, { + method: 'POST', + url: `${this.getApiPrefix()}/iqr/{sid}/refine`, + path: { + sid: sessionId, + }, + }); + } + + public static iqrAdjudicate(sessionId: string, adjudications: Array<{ uuid: string, status: 'positive' | 'neutral' | 'negative' }>): CancelablePromise<{ success: boolean }> { + return __request(OpenAPI, { + method: 'POST', + url: `${this.getApiPrefix()}/iqr/{sid}/adjudicate`, + path: { + sid: sessionId, + }, + body: { adjudications }, + }); + } + + public static iqrGetSessionInfo(sessionId: string): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: `${this.getApiPrefix()}/iqr/{sid}`, + path: { + sid: sessionId, + }, + }); + } + + public static iqrGetOrderedResults(sessionId: string): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: 
`${this.getApiPrefix()}/iqr/{sid}/results`, + path: { + sid: sessionId, + }, + }); + } + + public static iqrGetSiteImageUrl(siteId: string): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: `${this.getApiPrefix()}/iqr/site-image-url/{siteId}`, + path: { + siteId, + }, + }); + } + public static getSatelliteTimestamps( constellation="S2", spectrum="visual", diff --git a/vue/src/components/MapLibre.vue b/vue/src/components/MapLibre.vue index ad2302a2f..50735ab36 100644 --- a/vue/src/components/MapLibre.vue +++ b/vue/src/components/MapLibre.vue @@ -7,7 +7,7 @@ import { } from "../mapstyle/rdwatchtiles"; import { filteredSatelliteTimeList, state } from "../store"; import { computed, markRaw, onBeforeUnmount, onMounted, onUnmounted, reactive, shallowRef, watch } from "vue"; -import type { FilterSpecification } from "maplibre-gl"; +import type { FilterSpecification, FitBoundsOptions } from "maplibre-gl"; import type { ShallowRef } from "vue"; import { popupLogic, setPopupEvents } from "../interactions/mouseEvents"; import useEditPolygon from "../interactions/editGeoJSON"; @@ -17,7 +17,10 @@ import { setSatelliteTimeStamp } from "../mapstyle/satellite-image"; import { isEqual, throttle } from 'lodash'; import { updateImageMapSources } from "../mapstyle/images"; import { FitBoundsEvent } from "../actions/map"; -import { type BoundingBox, getGeoJSONBounds } from "../utils"; +import { type BoundingBox, getGeoJSONBounds, timeoutBatch } from "../utils"; +import { useIQR } from "../use/useIQR"; +import { IQROrderedResultItem } from "../client/services/ApiService"; +import { useResizeObserver } from '../use/useResizeObserver'; const mapContainer: ShallowRef = shallowRef(null); const map: ShallowRef = shallowRef(null); @@ -28,22 +31,34 @@ const localGeoJSONFeatures = computed(() => { return state.localMapFeatureIds.map((id) => state.localMapFeatureById[id]); }); +const batchedResize = timeoutBatch(() => { map.value?.resize(); }, 50); + 
+useResizeObserver(mapContainer, () => { + batchedResize(); +}); + +const { state: iqrState } = useIQR(); +const iqrResults = computed(() => iqrState.results as IQROrderedResultItem[]); +const iqrSessionId = computed(() => iqrState.sessionId); + function setFilter(layerID: string, filter: FilterSpecification) { map.value?.setFilter(layerID, filter, { validate: false, }); } -function fitBounds(bbox: BoundingBox) { +function fitBounds(bbox: BoundingBox & FitBoundsOptions) { + const { xmin, ymin, xmax, ymax, ...options } = bbox; map.value?.fitBounds( [ - [bbox.xmin, bbox.ymin], - [bbox.xmax, bbox.ymax], + [xmin, ymin], + [xmax, ymax], ], { padding: 160, - duration: 5000, - } + duration: 4000, + ...options, + }, ); } @@ -74,6 +89,8 @@ onMounted(() => { Array.from(modelRunVectorLayers), regionIds, localGeoJSONFeatures.value, + iqrResults.value, + iqrSessionId.value, ), bounds: [ [-180, -90], @@ -106,7 +123,7 @@ const throttledSetSatelliteTimeStamp = throttle(setSatelliteTimeStamp, 300); watch([() => state.timestamp, () => state.filters, () => state.satellite, () => state.filters.scoringColoring, () => state.satellite.satelliteSources, () => state.enabledSiteImages, () => state.filters.hoverSiteId, () => state.modelRuns, () => state.openedModelRuns, () => state.filters.proposals, () => state.filters.randomKey, () => state.filters.editingGeoJSONSiteId, -localGeoJSONFeatures], (newVals, oldVals) => { +localGeoJSONFeatures, iqrResults, iqrSessionId], (newVals, oldVals) => { if (state.satellite.satelliteImagesOn) { throttledSetSatelliteTimeStamp(state, filteredSatelliteTimeList.value); @@ -137,7 +154,7 @@ localGeoJSONFeatures], (newVals, oldVals) => { updateImageMapSources(state.timestamp, state.enabledSiteImages, state.siteOverviewSatSettings, map.value ) } map.value?.setStyle( - style(state.timestamp, state.filters, state.satellite, state.enabledSiteImages, state.siteOverviewSatSettings, Array.from(modelRunVectorLayers), regionIds, localGeoJSONFeatures.value, 
state.filters.randomKey), + style(state.timestamp, state.filters, state.satellite, state.enabledSiteImages, state.siteOverviewSatSettings, Array.from(modelRunVectorLayers), regionIds, localGeoJSONFeatures.value, iqrResults.value, iqrSessionId.value, state.filters.randomKey), ); const siteFilter = buildSiteFilter(state.timestamp, state.filters); diff --git a/vue/src/components/filters/RegionFilter.vue b/vue/src/components/filters/RegionFilter.vue index 86818c548..cc5b96fc0 100755 --- a/vue/src/components/filters/RegionFilter.vue +++ b/vue/src/components/filters/RegionFilter.vue @@ -45,6 +45,9 @@ watch(selectedRegion, (val) => { if (router.currentRoute.value.fullPath.includes('proposals')) { prepend += 'proposals/' } + if (router.currentRoute.value.fullPath.startsWith('/iqr')) { + prepend += 'iqr/' + } if (val) { router.push(`${prepend}${val}`) emit("update:modelValue", val); diff --git a/vue/src/components/imageViewer/ImageViewer.vue b/vue/src/components/imageViewer/ImageViewer.vue index 508856be8..7859a75cf 100644 --- a/vue/src/components/imageViewer/ImageViewer.vue +++ b/vue/src/components/imageViewer/ImageViewer.vue @@ -325,6 +325,8 @@ const close = () => { clearInterval(embeddingCheckInterval.value[key]); }); state.selectedImageSite = null; + + emit('close'); } diff --git a/vue/src/components/iqr/IqrCandidate.vue b/vue/src/components/iqr/IqrCandidate.vue new file mode 100644 index 000000000..f8f2acbfb --- /dev/null +++ b/vue/src/components/iqr/IqrCandidate.vue @@ -0,0 +1,97 @@ + + + diff --git a/vue/src/components/iqr/IqrSimilarSites.vue b/vue/src/components/iqr/IqrSimilarSites.vue new file mode 100644 index 000000000..a3491c1b0 --- /dev/null +++ b/vue/src/components/iqr/IqrSimilarSites.vue @@ -0,0 +1,106 @@ + + + diff --git a/vue/src/components/siteList/SiteList.vue b/vue/src/components/siteList/SiteList.vue index 4f6e09b43..f280d99e5 100644 --- a/vue/src/components/siteList/SiteList.vue +++ b/vue/src/components/siteList/SiteList.vue @@ -124,6 +124,7 @@ const 
getSites = async (modelRun: string, initRun = false) => { downloading: item.downloading, details, proposal: !!details?.proposal, + smqtkUuid: item.smqtk_uuid, }); modelRunTitleList.value.push(details?.title || ''); totalCount.value += 1; diff --git a/vue/src/components/siteList/SiteListCard.vue b/vue/src/components/siteList/SiteListCard.vue index 3e91e6305..070e850d8 100644 --- a/vue/src/components/siteList/SiteListCard.vue +++ b/vue/src/components/siteList/SiteListCard.vue @@ -11,6 +11,7 @@ import { hoveredInfo } from "../../interactions/mouseEvents"; import ImageBrowser from './ImageBrowser.vue'; import ImageToggle from './ImageToggle.vue'; import AnimationDownloadDialog from "../animation/AnimationDownloadDialog.vue"; +import { useIQR } from "../../use/useIQR"; export interface SiteDisplay { @@ -44,6 +45,7 @@ export interface SiteDisplay { title: string; }; selectedSite?: SiteOverview; + smqtkUuid?: string | null; } const props = defineProps<{ @@ -110,6 +112,18 @@ const selectingSite = async (e: boolean) => { const animationDialog = ref(false); +const iqr = useIQR(); + +const runIQR = async (site: SiteDisplay) => { + if (!site.smqtkUuid) return; + iqr.setPrimarySite({ + name: site.name, + id: site.id, + smqtkUuid: site.smqtkUuid, + modelRunId: site.modelRunId, + }); + await iqr.initializeSession(); +};