Skip to content

Commit a0e1390

Browse files
authored
Update export_onnx.py
When saving a picture, restore the picture size. Add output label and box details.
1 parent 46bd4f1 commit a0e1390

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

rtdetr_pytorch/tools/export_onnx.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,18 @@ def forward(self, images, orig_target_sizes):
8282

8383

8484
# import onnxruntime as ort
85-
# from PIL import Image, ImageDraw
85+
# from PIL import Image, ImageDraw, ImageFont
8686
# from torchvision.transforms import ToTensor
87+
# from src.data.coco.coco_dataset import mscoco_category2name, mscoco_category2label, mscoco_label2category
8788

8889
# # print(onnx.helper.printable_graph(mm.graph))
8990

90-
# im = Image.open('./000000014439.jpg').convert('RGB')
91-
# im = im.resize((640, 640))
91+
# # Load the original image without resizing
92+
# original_im = Image.open('./hongkong.jpg').convert('RGB')
93+
# original_size = original_im.size
94+
95+
# # Resize the image for model input
96+
# im = original_im.resize((640, 640))
9297
# im_data = ToTensor()(im)[None]
9398
# print(im_data.shape)
9499

@@ -104,7 +109,7 @@ def forward(self, images, orig_target_sizes):
104109

105110
# labels, boxes, scores = output
106111

107-
# draw = ImageDraw.Draw(im)
112+
# draw = ImageDraw.Draw(original_im) # Draw on the original image
108113
# thrh = 0.6
109114

110115
# for i in range(im_data.shape[0]):
@@ -115,12 +120,17 @@ def forward(self, images, orig_target_sizes):
115120

116121
# print(i, sum(scr > thrh))
117122

118-
# for b in box:
119-
# draw.rectangle(list(b), outline='red',)
120-
# draw.text((b[0], b[1]), text=str(lab[i]), fill='blue', )
121-
122-
# im.save('test.jpg')
123-
123+
# for b, l in zip(box, lab):
124+
# # Scale the bounding boxes back to the original image size
125+
# b = [coord * original_size[j % 2] / 640 for j, coord in enumerate(b)]
126+
# # Get the category name from the label
127+
# category_name = mscoco_category2name[mscoco_label2category[l]]
128+
# draw.rectangle(list(b), outline='red', width=2)
129+
# font = ImageFont.truetype("Arial.ttf", 15)
130+
# draw.text((b[0], b[1]), text=category_name, fill='yellow', font=font)
131+
132+
# # Save the original image with bounding boxes
133+
# original_im.save('test.jpg')
124134

125135

126136
if __name__ == '__main__':

0 commit comments

Comments
 (0)