Using ViTDet for inference #4461
-
Hi, I've been trying to use the ViTDet models for inference on custom images, but I don't seem to be able to get correct results. For one, every time I spin up the model and run it on the same image following the same procedure, I get completely different results. Second, upon inspection, the results are also wrong, i.e., none of the reported labels are correct... I'm sure I'm doing something wrong, but I just can't figure out what it is. This is what I have tried so far:

```python
import torch
import numpy as np
from PIL import Image
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine.defaults import create_ddp_model
from detectron2.data import MetadataCatalog
cfg = LazyConfig.load("config/LVIS/mask_rcnn_vitdet_b_100ep.py")
metadata = MetadataCatalog.get(cfg.dataloader.train.dataset.names) # to get labels from ids
classes = metadata.thing_classes
model = instantiate(cfg.model)
model.to(cfg.train.device)
model = create_ddp_model(model)
DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
model.eval()
filename = 'PATH/TO/IMAGE'
image = Image.open(filename)
image = np.array(image, dtype=np.uint8)
image = np.moveaxis(image, -1, 0) # the model expects the image to be in channel first format
with torch.inference_mode():
    output = model([{'image': torch.from_numpy(image)}])
```

I based this snippet of code on the evaluation code in the ViTDet repo and then adjusted it to address the error messages that I encountered. An example output for the image I was using (in RGB format, which is what the model expects) from one run looks like this:

```
[{'instances': Instances(num_instances=300, image_height=800, image_width=450, fields=[pred_boxes: Boxes(tensor([[448.8199, 549.4664, 450.0000, 595.2785],
[448.8747, 565.3309, 450.0000, 611.2698],
[448.8156, 533.4473, 450.0000, 579.3156],
...,
[393.9430, 468.5813, 438.0978, 559.4935],
[265.4516, 0.0000, 310.6632, 53.5875],
[ 0.0000, 12.6200, 151.6600, 378.0115]], device='cuda:0')), scores: tensor([0.6469, 0.6454, 0.6431, 0.6417, 0.6389, 0.6386, 0.6385, 0.6385, 0.6381,
0.6380, 0.6378, 0.6376, 0.6365, 0.6364, 0.6363, 0.6359, 0.6354, 0.6349,
0.6349, 0.6345, 0.6344, 0.6343, 0.6342, 0.6339, 0.6334, 0.6333, 0.6333,
0.6331, 0.6329, 0.6329, 0.6329, 0.6328, 0.6327, 0.6326, 0.6326, 0.6325,
0.6322, 0.6321, 0.6318, 0.6316, 0.6316, 0.6315, 0.6314, 0.6314, 0.6312,
0.6309, 0.6308, 0.6308, 0.6308, 0.6306, 0.6306, 0.6305, 0.6305, 0.6302,
0.6302, 0.6301, 0.6298, 0.6295, 0.6295, 0.6292, 0.6291, 0.6291, 0.6291,
0.6291, 0.6290, 0.6289, 0.6289, 0.6288, 0.6288, 0.6287, 0.6286, 0.6285,
0.6283, 0.6283, 0.6282, 0.6282, 0.6282, 0.6282, 0.6279, 0.6278, 0.6278,
0.6277, 0.6277, 0.6276, 0.6275, 0.6274, 0.6274, 0.6273, 0.6273, 0.6272,
0.6271, 0.6270, 0.6270, 0.6269, 0.6268, 0.6268, 0.6268, 0.6267, 0.6267,
0.6267, 0.6267, 0.6266, 0.6266, 0.6266, 0.6266, 0.6265, 0.6265, 0.6264,
0.6263, 0.6262, 0.6261, 0.6260, 0.6260, 0.6259, 0.6258, 0.6257, 0.6257,
0.6256, 0.6256, 0.6256, 0.6255, 0.6255, 0.6254, 0.6254, 0.6252, 0.6251,
0.6251, 0.6251, 0.6251, 0.6251, 0.6251, 0.6250, 0.6250, 0.6249, 0.6249,
0.6248, 0.6248, 0.6247, 0.6246, 0.6246, 0.6245, 0.6245, 0.6245, 0.6244,
0.6244, 0.6244, 0.6244, 0.6243, 0.6243, 0.6243, 0.6243, 0.6242, 0.6241,
0.6241, 0.6240, 0.6240, 0.6240, 0.6240, 0.6240, 0.6240, 0.6239, 0.6239,
0.6239, 0.6239, 0.6238, 0.6238, 0.6238, 0.6237, 0.6237, 0.6237, 0.6237,
0.6237, 0.6237, 0.6235, 0.6235, 0.6234, 0.6234, 0.6234, 0.6233, 0.6233,
0.6232, 0.6232, 0.6232, 0.6231, 0.6230, 0.6230, 0.6230, 0.6230, 0.6230,
0.6229, 0.6229, 0.6229, 0.6228, 0.6228, 0.6228, 0.6228, 0.6228, 0.6227,
0.6227, 0.6227, 0.6226, 0.6226, 0.6225, 0.6225, 0.6225, 0.6225, 0.6224,
0.6224, 0.6224, 0.6223, 0.6223, 0.6223, 0.6223, 0.6222, 0.6222, 0.6222,
0.6222, 0.6222, 0.6221, 0.6221, 0.6221, 0.6221, 0.6220, 0.6220, 0.6220,
0.6219, 0.6219, 0.6218, 0.6217, 0.6217, 0.6217, 0.6217, 0.6217, 0.6217,
0.6216, 0.6216, 0.6216, 0.6215, 0.6215, 0.6215, 0.6215, 0.6215, 0.6214,
0.6214, 0.6214, 0.6214, 0.6214, 0.6213, 0.6213, 0.6213, 0.6213, 0.6213,
0.6213, 0.6212, 0.6212, 0.6212, 0.6212, 0.6211, 0.6211, 0.6210, 0.6210,
0.6210, 0.6209, 0.6209, 0.6209, 0.6208, 0.6208, 0.6208, 0.6208, 0.6207,
0.6207, 0.6207, 0.6207, 0.6207, 0.6206, 0.6206, 0.6206, 0.6206, 0.6206,
0.6206, 0.6206, 0.6205, 0.6205, 0.6205, 0.6205, 0.6205, 0.6205, 0.6205,
0.6205, 0.6205, 0.6205, 0.6204, 0.6204, 0.6204, 0.6204, 0.6204, 0.6204,
0.6204, 0.6204, 0.6204], device='cuda:0'), pred_classes: tensor([ 149, 149, 149, 598, 149, 1137, 90, 149, 598, 149, 151, 149,
1137, 149, 149, 149, 149, 532, 1137, 598, 622, 149, 598, 22,
598, 598, 149, 598, 1190, 634, 708, 149, 698, 598, 598, 598,
149, 698, 149, 1190, 598, 149, 149, 1190, 1190, 1137, 149, 1129,
708, 149, 1054, 149, 149, 698, 90, 242, 219, 149, 598, 149,
149, 1190, 149, 598, 149, 532, 634, 149, 149, 1137, 1190, 149,
1190, 1190, 598, 149, 598, 149, 708, 1190, 219, 532, 149, 1054,
1190, 151, 460, 149, 149, 598, 151, 708, 1190, 493, 149, 1190,
634, 149, 493, 708, 149, 698, 219, 219, 698, 708, 1190, 532,
532, 1137, 236, 634, 545, 708, 455, 22, 545, 698, 698, 591,
149, 708, 149, 1054, 598, 421, 149, 708, 149, 708, 149, 22,
460, 947, 540, 947, 708, 1190, 149, 149, 634, 1137, 219, 947,
598, 1054, 1137, 90, 1054, 493, 149, 149, 1190, 1054, 149, 149,
1190, 149, 711, 446, 460, 149, 151, 1190, 219, 1118, 698, 598,
698, 219, 698, 532, 598, 149, 1190, 1190, 698, 598, 598, 149,
149, 598, 1054, 149, 1054, 1190, 149, 1054, 1190, 698, 698, 149,
1190, 149, 460, 1054, 219, 698, 532, 532, 291, 698, 532, 149,
711, 151, 698, 532, 149, 913, 149, 1190, 1054, 1190, 1190, 1190,
1190, 598, 149, 708, 219, 1129, 219, 149, 219, 698, 708, 1190,
149, 1190, 1190, 455, 598, 532, 1056, 1118, 598, 598, 149, 1190,
149, 455, 242, 455, 1054, 598, 151, 149, 1129, 1190, 1190, 532,
532, 698, 598, 634, 532, 1190, 708, 1190, 708, 598, 1190, 149,
768, 1190, 1190, 708, 532, 1190, 537, 598, 698, 698, 151, 149,
1190, 219, 532, 598, 1190, 1056, 149, 1112, 698, 1190, 219, 532,
219, 149, 708, 711, 634, 90, 149, 149, 598, 698, 149, 509],
device='cuda:0'), pred_masks: tensor([[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
...,
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
       [False, False, False,  ..., False, False, False]]], device='cuda:0')])}]
```

Then, I inspected the labels by doing:

```python
indices = output[0]['instances'].get_fields()['pred_classes'].to("cpu").numpy()
labels = [classes[idx] for idx in indices]  # classes = metadata.thing_classes was defined above
```

I also tried to use the Visualizer class to plot the results (roughly along the lines of the sketch below).
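What that attempt looked like, more or less (a sketch rather than my exact code; `filename`, `metadata`, and `output` are the variables defined in the snippet above):

```python
from detectron2.utils.visualizer import Visualizer

# Sketch: draw the predicted boxes/masks/labels on top of the original image.
# `filename`, `metadata`, and `output` come from the snippet above.
img = np.array(Image.open(filename))  # HxWxC RGB array, original resolution
v = Visualizer(img, metadata=metadata)
vis = v.draw_instance_predictions(output[0]['instances'].to('cpu'))
vis.save('predictions.png')  # or vis.get_image() to get the drawn array back
```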
However, it would seem as if the bounding boxes are encoded differently from what the Visualizer expects, since the plots don't look right either (putting aside the fact that the labels are all wrong). I also noticed that the scores always fall within the range [0.6, 0.65], which looks suspicious too... One more thing I should mention: when lazily initialising the model, I get some warnings. I'm not sure if this is expected, but since this is what I get when loading the checkpoint that's available, I assume it must be... I found two issues in the Detectron2 repo related to this question, #4439 and #4415, but they mostly ask for documentation to be added rather than asking a specific question, so I thought this was a more appropriate place to ask. Any help would be appreciated!
-
UPDATE: After going through the Detectron2 library, trying to follow every operation performed during evaluation to check whether I was missing anything, I eventually realised that I had been loading the checkpoint of the pre-trained model used as a starting point for training... 🤦♂️ For some reason I assumed that the init_checkpoint in the config file corresponded to the trained (final) model... my bad 😅

So, downloading the trained model corresponding to the config file you are using and adding the following line to the script should let you load the trained model:

```python
cfg.train.init_checkpoint = '<YOUR_PATH>/model_final_5251c5.pkl'
```

That checkpoint corresponds to the ViTDet, ViT-B model using Mask R-CNN trained on LVIS. The complete snippet of code to run inference on a custom image then becomes:

```python
import torch
import numpy as np
from PIL import Image
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine.defaults import create_ddp_model
from detectron2.data import MetadataCatalog
cfg = LazyConfig.load("config/LVIS/mask_rcnn_vitdet_b_100ep.py")
cfg.train.init_checkpoint = '<YOUR_PATH>/model_final_5251c5.pkl'  # replace with the path where you saved the trained model
metadata = MetadataCatalog.get(cfg.dataloader.train.dataset.names) # to get labels from ids
classes = metadata.thing_classes
model = instantiate(cfg.model)
model.to(cfg.train.device)
model = create_ddp_model(model)
DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
model.eval()
filename = 'PATH/TO/IMAGE'
image = Image.open(filename)
image = np.array(image, dtype=np.uint8)
image = np.moveaxis(image, -1, 0) # the model expects the image to be in channel first format
with torch.inference_mode():
    output = model([{'image': torch.from_numpy(image)}])
```
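With the trained weights loaded, the predictions can then be mapped back to readable label names. A quick sketch reusing the `output` and `classes` variables from the snippet above (the 0.5 score threshold is just an arbitrary choice for illustration):

```python
# Continues from the snippet above: `output` and `classes` are already defined.
instances = output[0]['instances'].to('cpu')
keep = instances.scores > 0.5               # arbitrary confidence threshold
instances = instances[keep]
labels = [classes[i] for i in instances.pred_classes.tolist()]
print(list(zip(labels, instances.scores.tolist())))
```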