33import PIL .Image
44import torch
55import torchvision
6- from torchvision .models .detection .faster_rcnn import FastRCNNPredictor
6+ import torchvision .models .detection .anchor_utils
7+ import torchvision .models .detection .backbone_utils
8+ import torchvision .models .detection .faster_rcnn
9+ import torchvision .models .mobilenetv3
710
811from trapdata import TrapImage , db , logger
912from trapdata .db .models .detections import save_detected_objects
@@ -147,7 +150,7 @@ def save_results(self, item_ids, batch_output):
147150 )
148151
149152
150- class MothObjectDetector_FasterRCNN (ObjectDetector ):
153+ class MothObjectDetector_FasterRCNN_2021 (ObjectDetector ):
151154 name = "FasterRCNN for AMI Moth Traps 2021"
152155 weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/localization/v1_localizmodel_2021-08-17-12-06.pt"
153156 description = (
@@ -160,7 +163,11 @@ def get_model(self):
160163 model = torchvision .models .detection .fasterrcnn_resnet50_fpn (weights = None )
161164 num_classes = 2 # 1 class (object) + background
162165 in_features = model .roi_heads .box_predictor .cls_score .in_features
163- model .roi_heads .box_predictor = FastRCNNPredictor (in_features , num_classes )
166+ model .roi_heads .box_predictor = (
167+ torchvision .models .detection .faster_rcnn .FastRCNNPredictor (
168+ in_features , num_classes
169+ )
170+ )
164171 logger .debug (f"Loading weights: { self .weights } " )
165172 checkpoint = torch .load (self .weights , map_location = self .device )
166173 state_dict = checkpoint .get ("model_state_dict" ) or checkpoint
@@ -186,34 +193,92 @@ def post_process_single(self, output):
186193 return bboxes
187194
188195
class MothObjectDetector_FasterRCNN_2023(ObjectDetector):
    """Faster R-CNN (ResNet-50 FPN) moth localizer using the 2023 AMI checkpoint.

    The network is built with two classes (a single "object" class plus
    background), the checkpoint referenced by ``weights_path`` is loaded into
    it, and detections are filtered by ``bbox_score_threshold``.
    """

    name = "FasterRCNN for AMI Moth Traps 2023"
    weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/localization/fasterrcnn_resnet50_fpn_tz53qv9v.pt"
    description = (
        "Model trained on GBIF images and synthetic data in 2023. "
        "Accurate but can be slow on a machine without GPU."
    )
    # Detections whose score is not strictly above this value are dropped.
    bbox_score_threshold = 0.80

    def get_model(self):
        """Build the detector, load the checkpoint and switch to eval mode.

        The model is moved to ``self.device``, stored on ``self.model`` and
        also returned.
        """
        num_classes = 2  # 1 class (object) + background
        logger.debug(f"Loading weights: {self.weights}")
        # `weights=None` replaces the deprecated legacy `pretrained=False`
        # keyword; both request an uninitialised detection head, but the
        # legacy spelling emits deprecation warnings.
        # NOTE(review): `weights_backbone` is left at the builder's default,
        # as before. Confirm whether that default fetches ImageNet backbone
        # weights that `load_state_dict` below immediately overwrites.
        model = torchvision.models.get_model(
            name="fasterrcnn_resnet50_fpn",
            num_classes=num_classes,
            weights=None,
        )
        # NOTE(security): torch.load unpickles the file. The checkpoint is
        # fetched from a remote URL, so it must come from a trusted source
        # (consider `weights_only=True` where the installed torch supports it).
        checkpoint = torch.load(self.weights, map_location=self.device)
        # Accept either a training checkpoint that wraps the weights under
        # "model_state_dict" or a bare state dict.
        state_dict = checkpoint.get("model_state_dict") or checkpoint
        model.load_state_dict(state_dict)
        model = model.to(self.device)
        model.eval()
        self.model = model
        return self.model

    def post_process_single(self, output):
        """Convert one image's raw detections into a list of integer boxes.

        ``output`` is the per-image result dict and must provide "boxes",
        "scores" and "labels". Boxes whose score is strictly greater than
        ``bbox_score_threshold`` are kept, moved to the CPU, cast to ``int``
        and returned as nested Python lists.
        """
        # Sanity check: the model is built with a single object class, so
        # every detection is expected to carry label 1.
        assert all(label == 1 for label in output["labels"])

        # Filter out objects if their score is under score threshold
        keep = output["scores"] > self.bbox_score_threshold
        bboxes = output["boxes"][keep]

        logger.debug(
            f"Keeping {len(bboxes)} out of {len(output['boxes'])} objects found (threshold: {self.bbox_score_threshold})"
        )

        return bboxes.cpu().numpy().astype(int).tolist()
235+
236+
237+ class MothObjectDetector_FasterRCNN_MobileNet_2023 (ObjectDetector ):
238+ name = "FasterRCNN - MobileNet for AMI Moth Traps 2023"
239+ weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/localization/fasterrcnn_mobilenet_v3_large_fpn_uqfh7u9w.pt"
240+ description = (
241+ "Model trained on GBIF images and synthetic data in 2023. "
242+ "Slightly less accurate but much faster than other models."
243+ )
244+ bbox_score_threshold = 0.50
245+ trainable_backbone_layers = 6 # all layers are trained
246+ anchor_sizes = (64 , 128 , 256 , 512 )
247+ num_classes = 2
248+
249+ def get_model (self ):
250+ norm_layer = torch .nn .BatchNorm2d
251+ backbone = torchvision .models .mobilenetv3 .mobilenet_v3_large (
252+ weights = None , norm_layer = norm_layer
253+ )
254+ backbone = torchvision .models .detection .backbone_utils ._mobilenet_extractor (
255+ backbone , True , self .trainable_backbone_layers
256+ )
257+ anchor_sizes = (self .anchor_sizes ,) * 3
258+ aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
259+ model = torchvision .models .detection .faster_rcnn .FasterRCNN (
260+ backbone ,
261+ self .num_classes ,
262+ rpn_anchor_generator = torchvision .models .detection .anchor_utils .AnchorGenerator (
263+ anchor_sizes , aspect_ratios
264+ ),
265+ rpn_score_thresh = 0.05 ,
266+ )
267+ checkpoint = torch .load (self .weights , map_location = self .device )
268+ state_dict = checkpoint .get ("model_state_dict" ) or checkpoint
269+ model .load_state_dict (state_dict )
270+ model = model .to (self .device )
271+ model .eval ()
272+ self .model = model
273+ return self .model
274+
275+ def post_process_single (self , output ):
276+ # This model does not use the labels from the object detection model
277+ _ = output ["labels" ]
278+ assert all ([label == 1 for label in output ["labels" ]])
279+
280+ # Filter out objects if their score is under score threshold
281+ bboxes = output ["boxes" ][output ["scores" ] > self .bbox_score_threshold ]
217282
218283 logger .debug (
219284 f"Keeping { len (bboxes )} out of { len (output ['boxes' ])} objects found (threshold: { self .bbox_score_threshold } )"
0 commit comments