@@ -23,12 +23,8 @@ def __call__(cls, *args, **kwargs):
2323class Recognizer (metaclass = SingletonMeta ):
2424 def __init__ (self ):
2525 root_dir = os .path .dirname (os .path .dirname (__file__ ))
26- multi_cls_model_path = os .path .join (root_dir , 'captcha_recognizer' , 'models' , 'multi_cls.onnx' )
27- single_cls_model_path = os .path .join (root_dir , 'captcha_recognizer' , 'models' , 'single_cls.onnx' )
28-
29- self .multi_cls_model : cv2 .dnn .Net = cv2 .dnn .readNetFromONNX (multi_cls_model_path )
30-
31- self .single_cls_model : cv2 .dnn .Net = cv2 .dnn .readNetFromONNX (single_cls_model_path )
26+ slider_v1_model_path = os .path .join (root_dir , 'captcha_recognizer' , 'models' , 'slider-v1.onnx' )
27+ self .model_v1 : cv2 .dnn .Net = cv2 .dnn .readNetFromONNX (slider_v1_model_path )
3228
3329 @staticmethod
3430 def image_to_array (source : Union [str , Path , bytes , np .ndarray ] = None ):
@@ -110,26 +106,21 @@ def predict(self, model, source: Union[str, Path, bytes, np.ndarray] = None, con
110106
111107 return detections
112108
113- def identify_gap (self , source , is_single = False , conf = CONF_THRESHOLD , ** kwargs ):
109+ def identify_gap (self , source , conf = CONF_THRESHOLD , ** kwargs ):
114110 """
115111 识别给定图片的缺口。
116112
117113 参数:
118114 - source: 图片源。
119- - is_single: 布尔值,指示是否为单缺口图片。
120115 - conf: 置信度
121116
122117 返回:
123118 - box: 一个列表,包含具有最高置信度的间隙的边界框坐标。
124119 - box_conf: 浮点数,代表间隙的置信度。
125120 """
126- if is_single :
127- model = self .single_cls_model
128- classes = [0 , 1 , 2 ]
129- else :
130- model = self .multi_cls_model
131- classes = [0 ]
132- results = self .predict (model = model , source = source , conf = conf )
121+
122+ classes = [0 ]
123+ results = self .predict (model = self .model_v1 , source = source , conf = conf )
133124 box = []
134125 box_conf = 0
135126 if not len (results ):
@@ -164,7 +155,7 @@ def calculate_difference(slider, box):
164155
165156 def identify_boxes_by_screenshot (self , source : Union [str , Path , bytes , np .ndarray ]):
166157 # 通过截图图片识别所有box
167- results = self .predict (model = self .single_cls_model , source = source )
158+ results = self .predict (model = self .model_v1 , source = source )
168159
169160 box_list = []
170161 if not len (results ):
0 commit comments