1111from immich_ml .config import log
1212from immich_ml .models .base import InferenceModel
1313from immich_ml .models .transforms import decode_cv2
14- from immich_ml .schemas import ModelSession , ModelTask , ModelType
14+ from immich_ml .schemas import ModelFormat , ModelSession , ModelTask , ModelType
15+ from immich_ml .sessions .ort import OrtSession
1516
1617from .schemas import OcrOptions , TextDetectionOutput
1718
@@ -21,14 +22,14 @@ class TextDetector(InferenceModel):
2122 identity = (ModelType .DETECTION , ModelTask .OCR )
2223
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
    """Initialise the OCR text detector with its default detection settings.

    Args:
        model_name: Name of the detection model, forwarded to the parent
            initialiser.
        **model_kwargs: Additional options forwarded unchanged to the parent
            initialiser. ``model_format`` is always supplied here as
            ``ModelFormat.ONNX``, so callers must not pass it themselves.
    """
    super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)

    # Defaults; each can be overridden per request through ``configure``.
    self.max_resolution = 736
    self.min_score = 0.5
    self.score_mode = "fast"

    # Shared result handed back when detection yields nothing usable.
    # All three fields are empty float32 arrays so consumers can rely on
    # a uniform array type for every key.
    self._empty: TextDetectionOutput = {
        "image": np.empty(0, dtype=np.float32),
        "boxes": np.empty(0, dtype=np.float32),
        "scores": np.empty(0, dtype=np.float32),
    }
3334
3435 def _download (self ) -> None :
@@ -50,7 +51,8 @@ def _download(self) -> None:
5051 DownloadFile .run (download_params )
5152
5253 def _load (self ) -> ModelSession :
53- session = self ._make_session (self .model_path )
54+ # TODO: support other runtime sessions
55+ session = OrtSession (self .model_path )
5456 self .model = RapidTextDetector (
5557 OcrOptions (
5658 session = session .session ,
@@ -62,17 +64,23 @@ def _load(self) -> ModelSession:
6264 )
6365 return session
6466
65- def configure (self , ** kwargs : Any ) -> None :
66- self .max_resolution = kwargs .get ("maxResolution" , self .max_resolution )
67- self .min_score = kwargs .get ("minScore" , self .min_score )
68- self .score_mode = kwargs .get ("scoreMode" , self .score_mode )
69-
def _predict(self, inputs: bytes | Image.Image) -> TextDetectionOutput:
    """Detect text regions in a single image.

    Args:
        inputs: The image, either as encoded bytes or as a PIL image; it is
            converted with ``decode_cv2`` before being given to the detector.

    Returns:
        A mapping with the detector's output image under ``"image"`` and the
        detected ``"boxes"`` and ``"scores"`` as float32 arrays. When the
        detector reports no boxes, no scores, or no image, the shared empty
        result ``self._empty`` is returned instead.
    """
    results = self.model(decode_cv2(inputs))

    # Any missing field means there is nothing usable to report.
    if results.boxes is None or results.scores is None or results.img is None:
        return self._empty

    return {
        "image": results.img,
        "boxes": np.array(results.boxes, dtype=np.float32),
        "scores": np.array(results.scores, dtype=np.float32),
    }
76+
def configure(self, **kwargs: Any) -> None:
    """Apply per-request overrides to the detector settings.

    Only options that are present and not ``None`` are applied; everything
    else keeps its current value. Each accepted value is stored on this
    instance and also written to the already-loaded detector in
    ``self.model``, so ``self.model`` must exist when an option is supplied.

    Args:
        **kwargs: Optional overrides. Recognised keys:
            ``maxResolution`` -> ``max_resolution`` / ``model.limit_side_len``
            ``minScore`` -> ``min_score`` / ``model.postprocess_op.box_thresh``
            ``scoreMode`` -> ``score_mode`` / ``model.postprocess_op.score_mode``
            Unrecognised keys are ignored.
    """
    if (max_resolution := kwargs.get("maxResolution")) is not None:
        self.max_resolution = max_resolution
        self.model.limit_side_len = max_resolution

    if (min_score := kwargs.get("minScore")) is not None:
        self.min_score = min_score
        self.model.postprocess_op.box_thresh = min_score

    if (score_mode := kwargs.get("scoreMode")) is not None:
        self.score_mode = score_mode
        self.model.postprocess_op.score_mode = score_mode
0 commit comments