
Commit 0f55d0b

Merge pull request #64 from RolnickLab/feat/2024-release

2024 Release updates

2 parents: 266a7a4 + 9188978

37 files changed: 3,883 additions & 1,418 deletions

.gitignore

Lines changed: 4 additions & 4 deletions

```diff
@@ -102,13 +102,13 @@ celerybeat.pid
 *.sage.py
 
 # Environments
-.env
-.venv
-env/
-venv/
+.env*
+.venv/
+.venv*/
 ENV/
 env.bak/
 venv.bak/
+bak/
 
 # Spyder project settings
 .spyderproject
```

README.md

Lines changed: 58 additions & 11 deletions

````diff
@@ -60,6 +60,14 @@ Test the whole backend pipeline without the GUI using this command
 
 ```sh
 python trapdata/tests/test_pipeline.py
+# or
+ami test pipeline
+```
+
+Run all other tests with:
+
+```sh
+ami test all
 ```
 
 ## GUI Usage
@@ -149,19 +157,58 @@ A script is available in the repo source to run the commands above.
 
 
 
-## KG Notes for adding new models
+## Adding new models
 
-- To add new models, save the pt and json files to:
-```
-~/Library/Application Support/trapdata/models
-```
-or wherever you set the appropriate dir in settings.
-The json file is simply a dict of species name and index.
+1) Create a new inference class in `trapdata/ml/models/classification.py` or `trapdata/ml/models/localization.py`. All models inherit from `InferenceBaseClass`, but there are more specific classes for classification and localization and different architectures. Choose the appropriate class to inherit from. It's best to copy an existing inference class that is similar to the new model you are adding.
+
+2) Upload your model weights and category map to a cloud storage service and make sure the file is publicly accessible via a URL. The weights will be downloaded the first time the model is run. Alternatively, you can manually add the model weights to the configured `USER_DATA_PATH` directory under the subdir `USER_DATA_PATH/models/` (on macOS this is `~/Library/Application Support/trapdata/models`). However the model will not be available to other users unless they also manually add the model weights. The category map json file is simply a dict of species names and their indexes in your model's last layer. See the existing category maps for examples.
+
+3) Select your model in the GUI settings or set the `SPECIES_CLASSIFICATION_MODEL` setting. If the model inherits from the `SpeciesClassifier` class, it will automatically become one of the valid choices.
 
-Then you need to create a class in `trapdata/ml/models/classification.py` or `trapdata/ml/models/localization.py` and add the model details.
+## Clearing the cache & starting fresh
 
-- To clear the cache:
+Remove the index of images, all detections and classifications by removing the database file. This will not remove the images themselves, only the metadata about them. The database is located in the user data directory.
 
+On macOS:
 ```
-rm ~/Library/Application\ Support/trapdata/trapdata.db
-```
+rm ~/Library/Application\ Support/trapdata/trapdata.db
+```
+
+On Linux:
+```
+rm ~/.config/trapdata/trapdata.db
+```
+
+On Windows:
+```
+del %AppData%\trapdata\trapdata.db
+```
+
+## Running the web API
+
+The model inference pipeline can be run as a web API using FastAPI. This is what the Antenna platform uses to process images.
+
+To run the API, use the following command:
+
+```sh
+ami api
+```
+
+View the interactive API docs at http://localhost:2000/
+
+
+## Web UI demo (Gradio)
+
+A simple web UI is also available to test the inference pipeline. This is a quick way to test models on a remote server via a web browser.
+
+```sh
+ami gradio
+```
+
+Open http://localhost:7861/
+
+Use ngrok to temporarily expose localhost to the internet:
+
+```sh
+ngrok http 7861
+```
````
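The README change above notes that the category map is "simply a dict of species names and their indexes" in the model's last layer. A minimal sketch of such a file, with illustrative species names (consult the existing category maps in the repo for the real ones):

```python
import json

# Minimal category map as described in the README: species name -> index of
# the corresponding output in the model's final layer. Names are examples only.
category_map = {
    "Actias luna": 0,
    "Hyalophora cecropia": 1,
    "Noctua pronuba": 2,
}

# Saved as a .json file next to the model weights.
print(json.dumps(category_map, indent=2))
```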

gradio.conf

Lines changed: 11 additions & 0 deletions (new file)

```ini
[program:gradio]
directory=/home/ubuntu/ami-data-companion
command=/home/ubuntu/ami-data-companion/.venv/bin/ami gradio
autostart=true
autorestart=true
# stopsignal=KILL
stopasgroup=true
killasgroup=true
stderr_logfile=/var/log/gradio.err.log
stdout_logfile=/var/log/gradio.out.log
# process_name=%(program_name)s_%(process_num)02d
```
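The new gradio.conf is a supervisord program definition in plain INI syntax, so it can be sanity-checked with the standard library's parser; a quick sketch using a trimmed copy of the section above:

```python
import configparser

# A trimmed copy of the [program:gradio] section; supervisord reads the same
# INI key=value format that configparser understands.
conf_text = """\
[program:gradio]
directory=/home/ubuntu/ami-data-companion
command=/home/ubuntu/ami-data-companion/.venv/bin/ami gradio
autostart=true
autorestart=true
"""

parser = configparser.ConfigParser()
parser.read_string(conf_text)
section = parser["program:gradio"]
print(section["command"])  # /home/ubuntu/ami-data-companion/.venv/bin/ami gradio
```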

poetry.lock

Lines changed: 2006 additions & 1001 deletions (generated lockfile; diff not rendered)

pyproject.toml

Lines changed: 12 additions & 5 deletions

```diff
@@ -13,12 +13,11 @@ requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.poetry.dependencies]
-python = "^3.9"
+python = "^3.10"
 pillow = "^9.5.0"
 python-dateutil = "^2.8.2"
 python-dotenv = "^1.0.0"
-pydantic = "^1.10.7"
-typer = "^0.7.0"
+pydantic = "^2.5.0"
 rich = "^13.3.3"
 pandas = "^1.5.3"
 sqlalchemy = ">2.0"
@@ -27,8 +26,8 @@ alembic = "^1.10.2"
 psycopg2-binary = { version = "^2.9.5", optional = true }
 sentry-sdk = "^1.18.0"
 imagesize = "^1.4.1"
-torch = "^2.0.0"
-torchvision = "^0.15.1"
+torch = "^2.1.0"
+torchvision = "^0.16.0"
 timm = "^0.6.13"
 structlog = "^22.3.0"
 kivy = { extras = ["base"], version = "^2.3.0" }
@@ -45,6 +44,14 @@ ipython = "^8.11.0"
 pytest-cov = "^4.0.0"
 pytest-asyncio = "^0.21.0"
 pytest = "*"
+numpy = "^1.26.2"
+pip = "^23.3.1"
+pydantic-settings = "^2.1.0"
+boto3 = "^1.33.0"
+botocore = "^1.33.0"
+mypy-boto3-s3 = "^1.29.7"
+typer = "^0.12.3"
+gradio = "^4.41.0"
 
 
 [tool.pytest.ini_options]
```
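Every bump in this file uses Poetry's caret operator, where `^X.Y.Z` allows upgrades that do not change the leftmost non-zero version component. A small sketch of that rule (a simplified re-implementation for illustration, not Poetry's own code):

```python
def caret_bounds(spec: str) -> tuple[str, str]:
    """Return (inclusive lower, exclusive upper) bounds for a Poetry caret constraint."""
    parts = [int(p) for p in spec.lstrip("^").split(".")]
    while len(parts) < 3:
        parts.append(0)  # pad e.g. "^2.1" to 2.1.0
    major, minor, patch = parts[:3]
    if major > 0:
        upper = [major + 1, 0, 0]
    elif minor > 0:
        upper = [0, minor + 1, 0]
    else:
        upper = [0, 0, patch + 1]
    return ".".join(map(str, parts[:3])), ".".join(map(str, upper))

print(caret_bounds("^2.5.0"))   # pydantic: >=2.5.0, <3.0.0
print(caret_bounds("^0.16.0"))  # torchvision: >=0.16.0, <0.17.0
```

So the torch and torchvision bumps stay within their new minor series, while the pydantic bump opts into the whole 2.x line.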

scripts/start_db_container.sh

Lines changed: 2 additions & 2 deletions

```diff
@@ -4,13 +4,13 @@ set -o errexit
 set -o nounset
 
 CONTAINER_NAME=ami-db
-HOST_PORT=5432
+HOST_PORT=5433
 POSTGRES_VERSION=14
 POSTGRES_DB=ami
 
 docker run -d -i --name $CONTAINER_NAME -v "$(pwd)/db_data":/var/lib/postgresql/data --restart always -p $HOST_PORT:5432 -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=$POSTGRES_DB postgres:$POSTGRES_VERSION
 
 docker logs ami-db --tail 100
 
-echo 'Database started, Connection string: "postgresql://postgres@localhost:5432/ami"'
+echo "Database started, Connection string: \"postgresql://postgres@localhost:${HOST_PORT}/${POSTGRES_DB}\""
 echo "Stop (and destroy) database with 'docker stop $CONTAINER_NAME' && docker remove $CONTAINER_NAME"
```
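With the host port moved to 5433 (avoiding a clash with any Postgres already listening on 5432) and the echo line parameterized, clients build the connection string from the same variables; a small sketch, where the actual `psql` call assumes the ami-db container from the script above is running:

```shell
# Reconstruct the connection string the updated script echoes.
HOST_PORT=5433
POSTGRES_DB=ami
CONN="postgresql://postgres@localhost:${HOST_PORT}/${POSTGRES_DB}"
echo "$CONN"
# psql "$CONN"   # uncomment to connect once the ami-db container is up
```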

trapdata/__init__.py

Lines changed: 7 additions & 6 deletions

```diff
@@ -1,17 +1,18 @@
 import sentry_sdk
 
+from .common import constants, utils
+from .common.logs import logger
+from .db.models.detections import DetectedObject
+from .db.models.events import MonitoringSession
+from .db.models.images import TrapImage
+
 sentry_sdk.init(
     dsn="https://d2f65f945fe343669bbd3be5116d5922@o4503927026876416.ingest.sentry.io/4503927029497856",
     traces_sample_rate=1.0,
 )
-#
+
 # import multiprocessing
 
-from .common import constants, utils
-from .common.logs import logger
-from .db.models.detections import DetectedObject
-from .db.models.events import MonitoringSession
-from .db.models.images import TrapImage
 
 __all__ = [
     logger,
```

trapdata/api/__init__.py

Lines changed: 3 additions & 0 deletions (new file)

```python
from trapdata.settings import read_settings

settings = read_settings()
```

trapdata/api/api.py

Lines changed: 194 additions & 0 deletions (new file)

```python
"""
Fast API interface for processing images through the localization and classification pipelines.
"""

import enum
import time

import fastapi
import pydantic
from rich import print

from ..common.logs import logger  # noqa: F401
from . import settings
from .models.classification import (
    APIMothClassifier,
    MothClassifierBinary,
    MothClassifierGlobal,
    MothClassifierPanama,
    MothClassifierPanama2024,
    MothClassifierQuebecVermont,
    MothClassifierTuringAnguilla,
    MothClassifierTuringCostaRica,
    MothClassifierUKDenmark,
)
from .models.localization import APIMothDetector
from .schemas import Detection, SourceImage

app = fastapi.FastAPI()


class SourceImageRequest(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="ignore")

    # @TODO bring over new SourceImage & b64 validation from the lepsAI repo
    id: str
    url: str
    # b64: str | None = None


class SourceImageResponse(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="ignore")

    id: str
    url: str


PIPELINE_CHOICES = {
    "panama_moths_2023": MothClassifierPanama,
    "panama_moths_2024": MothClassifierPanama2024,
    "quebec_vermont_moths_2023": MothClassifierQuebecVermont,
    "uk_denmark_moths_2023": MothClassifierUKDenmark,
    "costa_rica_moths_turing_2024": MothClassifierTuringCostaRica,
    "anguilla_moths_turing_2024": MothClassifierTuringAnguilla,
    "global_moths_2024": MothClassifierGlobal,
}
_pipeline_choices = dict(zip(PIPELINE_CHOICES.keys(), list(PIPELINE_CHOICES.keys())))


PipelineChoice = enum.Enum("PipelineChoice", _pipeline_choices)


class PipelineRequest(pydantic.BaseModel):
    pipeline: PipelineChoice
    source_images: list[SourceImageRequest]


class PipelineResponse(pydantic.BaseModel):
    pipeline: PipelineChoice
    total_time: float
    source_images: list[SourceImageResponse]
    detections: list[Detection]


@app.get("/")
async def root():
    return fastapi.responses.RedirectResponse("/docs")


@app.post("/pipeline/process")
@app.post("/pipeline/process/")
async def process(data: PipelineRequest) -> PipelineResponse:
    # Ensure that the source images are unique, filter out duplicates
    source_images_index = {
        source_image.id: source_image for source_image in data.source_images
    }
    incoming_source_images = list(source_images_index.values())
    if len(incoming_source_images) != len(data.source_images):
        logger.warning(
            f"Removed {len(data.source_images) - len(incoming_source_images)} duplicate source images"
        )

    source_image_results = [
        SourceImageResponse(**image.model_dump()) for image in incoming_source_images
    ]
    source_images = [
        SourceImage(**image.model_dump()) for image in incoming_source_images
    ]

    start_time = time.time()
    detector = APIMothDetector(
        source_images=source_images,
        batch_size=settings.localization_batch_size,
        num_workers=settings.num_workers,
        # single=True if len(source_images) == 1 else False,
        single=True,  # @TODO solve issues with reading images in multiprocessing
    )
    detector_results = detector.run()
    num_pre_filter = len(detector_results)

    filter = MothClassifierBinary(
        source_images=source_images,
        detections=detector_results,
        batch_size=settings.classification_batch_size,
        num_workers=settings.num_workers,
        # single=True if len(detector_results) == 1 else False,
        single=True,  # @TODO solve issues with reading images in multiprocessing
        filter_results=False,  # Only save results with the positive_binary_label, @TODO make this configurable from request
    )
    filter.run()
    # all_binary_classifications = filter.results

    # Compare num detections with num moth detections
    num_post_filter = len(filter.results)
    logger.info(
        f"Binary classifier returned {num_post_filter} out of {num_pre_filter} detections"
    )

    # Filter results based on positive_binary_label
    moth_detections = []
    non_moth_detections = []
    for detection in filter.results:
        for classification in detection.classifications:
            if classification.classification == filter.positive_binary_label:
                moth_detections.append(detection)
            elif classification.classification == filter.negative_binary_label:
                non_moth_detections.append(detection)
            break

    logger.info(
        f"Sending {len(moth_detections)} out of {num_pre_filter} detections to the classifier"
    )

    Classifier = PIPELINE_CHOICES[data.pipeline.value]
    classifier: APIMothClassifier = Classifier(
        source_images=source_images,
        detections=moth_detections,
        batch_size=settings.classification_batch_size,
        num_workers=settings.num_workers,
        # single=True if len(filtered_detections) == 1 else False,
        single=True,  # @TODO solve issues with reading images in multiprocessing
    )
    classifier.run()
    end_time = time.time()
    seconds_elapsed = float(end_time - start_time)

    # Return all detections, including those that were not classified as moths
    all_detections = classifier.results + non_moth_detections

    logger.info(
        f"Processed {len(source_images)} images in {seconds_elapsed:.2f} seconds"
    )
    logger.info(f"Returning {len(all_detections)} detections")
    print(all_detections)

    # If the number of detections is greater than 100, its suspicious. Log it.
    if len(all_detections) > 100:
        logger.warning(
            f"Detected {len(all_detections)} detections. This is suspicious and may contain duplicates."
        )

    response = PipelineResponse(
        pipeline=data.pipeline,
        source_images=source_image_results,
        detections=all_detections,
        total_time=seconds_elapsed,
    )
    return response


# Future methods

# batch processing
# async def process_batch(data: PipelineRequest) -> PipelineResponse:
#     pass

# render image crops and bboxes on top of the original image
# async def render(data: PipelineRequest) -> PipelineResponse:
#     pass


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=2000)
```
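Per the `PipelineRequest` and `SourceImageRequest` models above, a request to `POST /pipeline/process` only needs a pipeline slug (any key of `PIPELINE_CHOICES`) and a list of `{id, url}` source images. A minimal client-side sketch; the image URL is a placeholder, and actually sending the request assumes the server from `ami api` is running on port 2000:

```python
import json

# Payload matching the PipelineRequest / SourceImageRequest schemas above.
payload = {
    "pipeline": "quebec_vermont_moths_2023",  # any key of PIPELINE_CHOICES
    "source_images": [
        {"id": "img-001", "url": "https://example.com/snapshots/0001.jpg"},
    ],
}

# Send with any HTTP client, e.g.:
#   curl -X POST http://localhost:2000/pipeline/process \
#        -H 'Content-Type: application/json' -d @payload.json
print(json.dumps(payload))
```

The response body mirrors `PipelineResponse`: the echoed pipeline, total processing time, the deduplicated source images, and a list of detections.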
