
Commit 868d2d2

Localization v2 (#52)
* Add Leonard's new localization models
* Add back existing model
* Fix anchor sizes for mobilenet model
* Update new model names & descriptions
* Use new object detector in tests
* Update expected test results
1 parent 666fc8a commit 868d2d2

File tree

5 files changed (+568, -201 lines)


trapdata/common/utils.py

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ def slugify(s):
     # Quick method to make an acceptable attribute name or url part from a title
     # install python-slugify for handling unicode chars, numbers at the beginning, etc.
     separator = "_"
-    acceptable_chars = list(string.ascii_letters) + [separator]
+    acceptable_chars = list(string.ascii_letters) + list(string.digits) + [separator]
     return (
         "".join(
             [
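The fix means digits are no longer stripped out of slugs. The diff shows only the acceptable_chars line plus three context lines, so the following is a hedged standalone sketch of the patched helper: the comprehension body, the .lower() call, and the .strip() call are assumptions inferred from the visible "".join( context, not confirmed by the commit.

import string

def slugify(s):
    # Make an acceptable attribute name or URL part from a title
    separator = "_"
    acceptable_chars = list(string.ascii_letters) + list(string.digits) + [separator]
    return (
        "".join(
            # Assumed body: replace any unacceptable character with the separator
            [char if char in acceptable_chars else separator for char in str(s)]
        )
        .lower()
        .strip(separator)
    )

With this sketch, slugify("Trap 01") returns "trap_01"; before the fix the same input collapsed to "trap".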

trapdata/ml/models/localization.py

Lines changed: 82 additions & 17 deletions

@@ -3,7 +3,10 @@
 import PIL.Image
 import torch
 import torchvision
-from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+import torchvision.models.detection.anchor_utils
+import torchvision.models.detection.backbone_utils
+import torchvision.models.detection.faster_rcnn
+import torchvision.models.mobilenetv3

 from trapdata import TrapImage, db, logger
 from trapdata.db.models.detections import save_detected_objects
@@ -147,7 +150,7 @@ def save_results(self, item_ids, batch_output):
         )


-class MothObjectDetector_FasterRCNN(ObjectDetector):
+class MothObjectDetector_FasterRCNN_2021(ObjectDetector):
     name = "FasterRCNN for AMI Moth Traps 2021"
     weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/localization/v1_localizmodel_2021-08-17-12-06.pt"
     description = (

@@ -160,7 +163,11 @@ def get_model(self):
         model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
         num_classes = 2  # 1 class (object) + background
         in_features = model.roi_heads.box_predictor.cls_score.in_features
-        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
+        model.roi_heads.box_predictor = (
+            torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
+                in_features, num_classes
+            )
+        )
         logger.debug(f"Loading weights: {self.weights}")
         checkpoint = torch.load(self.weights, map_location=self.device)
         state_dict = checkpoint.get("model_state_dict") or checkpoint
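This hunk follows the standard torchvision recipe for loading a fine-tuned detector: instantiate the stock architecture with no weights, swap the box-classification head so its output size matches the fine-tuned checkpoint, then load the state dict. Below is a self-contained sketch of the same recipe; the function name and arguments are illustrative, not part of this repo, and the checkpoint fallback mirrors the lines above.

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def load_finetuned_fasterrcnn(weights_file: str, device: str = "cpu") -> torch.nn.Module:
    # Stock architecture, randomly initialized (no COCO weights)
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
    num_classes = 2  # 1 class (object) + background
    # Replace the head so cls_score/bbox_pred shapes match the checkpoint
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # The checkpoint may wrap the weights under "model_state_dict" or be a bare state dict
    checkpoint = torch.load(weights_file, map_location=device)
    state_dict = checkpoint.get("model_state_dict") or checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    return model.to(device)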
@@ -186,34 +193,92 @@ def post_process_single(self, output):
         return bboxes


-class GenericObjectDetector_FasterRCNN_MobileNet(ObjectDetector):
-    name = "Pre-trained FasterRCNN with MobileNet backend"
+class MothObjectDetector_FasterRCNN_2023(ObjectDetector):
+    name = "FasterRCNN for AMI Moth Traps 2023"
+    weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/localization/fasterrcnn_resnet50_fpn_tz53qv9v.pt"
     description = (
-        "Faster version of FasterRCNN but not trained on moth trap data. "
-        "Produces multiple overlapping bounding boxes. But helpful for testing on CPU machines."
+        "Model trained on GBIF images and synthetic data in 2023. "
+        "Accurate but can be slow on a machine without GPU."
     )
-    bbox_score_threshold = 0.01
+    bbox_score_threshold = 0.80

     def get_model(self):
-        model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
-            weights="DEFAULT"
+        num_classes = 2  # 1 class (object) + background
+        logger.debug(f"Loading weights: {self.weights}")
+        model = torchvision.models.get_model(
+            name="fasterrcnn_resnet50_fpn",
+            num_classes=num_classes,
+            pretrained=False,
         )
-        # @TODO can I use load_state_dict here with weights="DEFAULT"?
+        checkpoint = torch.load(self.weights, map_location=self.device)
+        state_dict = checkpoint.get("model_state_dict") or checkpoint
+        model.load_state_dict(state_dict)
         model = model.to(self.device)
         model.eval()
-        return model
+        self.model = model
+        return self.model

     def post_process_single(self, output):
         # This model does not use the labels from the object detection model
         _ = output["labels"]
+        assert all([label == 1 for label in output["labels"]])

         # Filter out objects if their score is under score threshold
-        bboxes = output["boxes"][
-            (output["scores"] > self.bbox_score_threshold) & (output["labels"] > 1)
-        ]
+        bboxes = output["boxes"][output["scores"] > self.bbox_score_threshold]

-        # Filter out background label, if using pretrained model only!
-        bboxes = output["boxes"][output["labels"] > 1]
+        logger.debug(
+            f"Keeping {len(bboxes)} out of {len(output['boxes'])} objects found (threshold: {self.bbox_score_threshold})"
+        )
+
+        bboxes = bboxes.cpu().numpy().astype(int).tolist()
+        return bboxes
+
+
+class MothObjectDetector_FasterRCNN_MobileNet_2023(ObjectDetector):
+    name = "FasterRCNN - MobileNet for AMI Moth Traps 2023"
+    weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/localization/fasterrcnn_mobilenet_v3_large_fpn_uqfh7u9w.pt"
+    description = (
+        "Model trained on GBIF images and synthetic data in 2023. "
+        "Slightly less accurate but much faster than other models."
+    )
+    bbox_score_threshold = 0.50
+    trainable_backbone_layers = 6  # all layers are trained
+    anchor_sizes = (64, 128, 256, 512)
+    num_classes = 2
+
+    def get_model(self):
+        norm_layer = torch.nn.BatchNorm2d
+        backbone = torchvision.models.mobilenetv3.mobilenet_v3_large(
+            weights=None, norm_layer=norm_layer
+        )
+        backbone = torchvision.models.detection.backbone_utils._mobilenet_extractor(
+            backbone, True, self.trainable_backbone_layers
+        )
+        anchor_sizes = (self.anchor_sizes,) * 3
+        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+        model = torchvision.models.detection.faster_rcnn.FasterRCNN(
+            backbone,
+            self.num_classes,
+            rpn_anchor_generator=torchvision.models.detection.anchor_utils.AnchorGenerator(
+                anchor_sizes, aspect_ratios
+            ),
+            rpn_score_thresh=0.05,
+        )
+        checkpoint = torch.load(self.weights, map_location=self.device)
+        state_dict = checkpoint.get("model_state_dict") or checkpoint
+        model.load_state_dict(state_dict)
+        model = model.to(self.device)
+        model.eval()
+        self.model = model
+        return self.model
+
+    def post_process_single(self, output):
+        # This model does not use the labels from the object detection model
+        _ = output["labels"]
+        assert all([label == 1 for label in output["labels"]])
+
+        # Filter out objects if their score is under score threshold
+        bboxes = output["boxes"][output["scores"] > self.bbox_score_threshold]

         logger.debug(
             f"Keeping {len(bboxes)} out of {len(output['boxes'])} objects found (threshold: {self.bbox_score_threshold})"
