33# https://github.com/Megvii-BaseDetection/YOLOX/blob/237e943ac64aa32eb32f875faa93ebb18512d41d/yolox/data/data_augment.py
44# https://github.com/Megvii-BaseDetection/YOLOX/blob/ac379df3c97d1835ebd319afad0c031c36d03f36/yolox/utils/demo_utils.py
55
6+ import os
67from typing import List
78
89import cv2
910import numpy as np
1011import onnxruntime
1112from huggingface_hub import hf_hub_download
13+ from huggingface_hub .constants import HUGGINGFACE_HUB_CACHE
14+ from onnxruntime .quantization import QuantType , quantize_dynamic
1215from PIL import Image
1316
1417from unstructured_inference .inference .layoutelement import LayoutElement
18+ from unstructured_inference .logger import logger
1519from unstructured_inference .models .unstructuredmodel import UnstructuredObjectDetectionModel
1620from unstructured_inference .utils import LazyDict , LazyEvaluateInfo
1721from unstructured_inference .visualize import draw_yolox_bounding_boxes
4751 ),
4852 label_map = YOLOX_LABEL_MAP ,
4953 ),
54+ "yolox_quantized" : {
55+ "model_path" : os .path .join (
56+ HUGGINGFACE_HUB_CACHE ,
57+ "yolox_quantized" ,
58+ "yolox_quantized.onnx" ,
59+ ),
60+ "label_map" : YOLOX_LABEL_MAP ,
61+ },
5062}
5163
5264
@@ -58,6 +70,15 @@ def predict(self, x: Image):
5870
5971 def initialize (self , model_path : str , label_map : dict ):
6072 """Start inference session for YoloX model."""
73+ if not os .path .exists (model_path ) and "yolox_quantized" in model_path :
74+ logger .info ("Quantized model don't currently exists, quantizing now..." )
75+ model_folder = "" .join (os .path .split (model_path )[:- 1 ])
76+ if not os .path .exists (model_folder ):
77+ os .mkdir (model_folder )
78+ source_path = MODEL_TYPES ["yolox" ]["model_path" ]
79+ quantize_dynamic (source_path , model_path , weight_type = QuantType .QUInt8 )
80+ self .model_path = model_path
81+
6182 self .model = onnxruntime .InferenceSession (
6283 model_path ,
6384 providers = [
@@ -66,6 +87,7 @@ def initialize(self, model_path: str, label_map: dict):
6687 "CPUExecutionProvider" ,
6788 ],
6889 )
90+
6991 self .layout_classes = label_map
7092
7193 def image_processing (
@@ -106,7 +128,13 @@ def image_processing(
106128 boxes_xyxy [:, 2 ] = boxes [:, 0 ] + boxes [:, 2 ] / 2.0
107129 boxes_xyxy [:, 3 ] = boxes [:, 1 ] + boxes [:, 3 ] / 2.0
108130 boxes_xyxy /= ratio
109- dets = multiclass_nms (boxes_xyxy , scores , nms_thr = 0.45 , score_thr = 0.1 )
131+
132+ # Note (Benjamin): Distinct models (quantized and original) requires distincts
133+ # levels of thresholds
134+ if "quantized" in self .model_path :
135+ dets = multiclass_nms (boxes_xyxy , scores , nms_thr = 0.0 , score_thr = 0.07 )
136+ else :
137+ dets = multiclass_nms (boxes_xyxy , scores , nms_thr = 0.1 , score_thr = 0.25 )
110138
111139 regions = []
112140
0 commit comments