Skip to content

Commit 585d093

Browse files
committed
apply config correctly, adjust defaults
1 parent 22690fa commit 585d093

File tree

10 files changed

+43
-35
lines changed

10 files changed

+43
-35
lines changed

machine-learning/immich_ml/models/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
5757
self.load()
5858
if model_kwargs:
5959
self.configure(**model_kwargs)
60-
return self._predict(*inputs, **model_kwargs)
60+
return self._predict(*inputs)
6161

6262
@abstractmethod
6363
def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any: ...

machine-learning/immich_ml/models/clip/textual.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
1919
depends = []
2020
identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
2121

22-
def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> str:
22+
def _predict(self, inputs: str, language: str | None = None) -> str:
2323
tokens = self.tokenize(inputs, language=language)
2424
res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
2525
return serialize_np_array(res)

machine-learning/immich_ml/models/clip/visual.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class BaseCLIPVisualEncoder(InferenceModel):
2626
depends = []
2727
identity = (ModelType.VISUAL, ModelTask.SEARCH)
2828

29-
def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> str:
29+
def _predict(self, inputs: Image.Image | bytes) -> str:
3030
image = decode_pil(inputs)
3131
res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
3232
return serialize_np_array(res)

machine-learning/immich_ml/models/facial_recognition/detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _load(self) -> ModelSession:
2424

2525
return session
2626

27-
def _predict(self, inputs: NDArray[np.uint8] | bytes, **kwargs: Any) -> FaceDetectionOutput:
27+
def _predict(self, inputs: NDArray[np.uint8] | bytes) -> FaceDetectionOutput:
2828
inputs = decode_cv2(inputs)
2929

3030
bboxes, landmarks = self._detect(inputs)

machine-learning/immich_ml/models/facial_recognition/recognition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _load(self) -> ModelSession:
4444
return session
4545

4646
def _predict(
47-
self, inputs: NDArray[np.uint8] | bytes | Image.Image, faces: FaceDetectionOutput, **kwargs: Any
47+
self, inputs: NDArray[np.uint8] | bytes | Image.Image, faces: FaceDetectionOutput
4848
) -> FacialRecognitionOutput:
4949
if faces["boxes"].shape[0] == 0:
5050
return []

machine-learning/immich_ml/models/ocr/detection.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from immich_ml.config import log
1212
from immich_ml.models.base import InferenceModel
1313
from immich_ml.models.transforms import decode_cv2
14-
from immich_ml.schemas import ModelSession, ModelTask, ModelType
14+
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
15+
from immich_ml.sessions.ort import OrtSession
1516

1617
from .schemas import OcrOptions, TextDetectionOutput
1718

@@ -21,14 +22,14 @@ class TextDetector(InferenceModel):
2122
identity = (ModelType.DETECTION, ModelTask.OCR)
2223

2324
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
24-
super().__init__(model_name, **model_kwargs)
25-
self.max_resolution = 1440
25+
super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
26+
self.max_resolution = 736
2627
self.min_score = 0.5
2728
self.score_mode = "fast"
2829
self._empty: TextDetectionOutput = {
29-
"resized": np.empty(0, dtype=np.float32),
30+
"image": np.empty(0, dtype=np.float32),
3031
"boxes": np.empty(0, dtype=np.float32),
31-
"scores": (),
32+
"scores": np.empty(0, dtype=np.float32),
3233
}
3334

3435
def _download(self) -> None:
@@ -50,7 +51,8 @@ def _download(self) -> None:
5051
DownloadFile.run(download_params)
5152

5253
def _load(self) -> ModelSession:
53-
session = self._make_session(self.model_path)
54+
# TODO: support other runtime sessions
55+
session = OrtSession(self.model_path)
5456
self.model = RapidTextDetector(
5557
OcrOptions(
5658
session=session.session,
@@ -62,17 +64,23 @@ def _load(self) -> ModelSession:
6264
)
6365
return session
6466

65-
def configure(self, **kwargs: Any) -> None:
66-
self.max_resolution = kwargs.get("maxResolution", self.max_resolution)
67-
self.min_score = kwargs.get("minScore", self.min_score)
68-
self.score_mode = kwargs.get("scoreMode", self.score_mode)
69-
70-
def _predict(self, inputs: bytes | Image.Image, **kwargs: Any) -> TextDetectionOutput:
67+
def _predict(self, inputs: bytes | Image.Image) -> TextDetectionOutput:
7168
results = self.model(decode_cv2(inputs))
7269
if results.boxes is None or results.scores is None or results.img is None:
7370
return self._empty
7471
return {
75-
"resized": results.img,
72+
"image": results.img,
7673
"boxes": np.array(results.boxes, dtype=np.float32),
7774
"scores": np.array(results.scores, dtype=np.float32),
7875
}
76+
77+
def configure(self, **kwargs: Any) -> None:
78+
if (max_resolution := kwargs.get("maxResolution")) is not None:
79+
self.max_resolution = max_resolution
80+
self.model.limit_side_len = max_resolution
81+
if (min_score := kwargs.get("minScore")) is not None:
82+
self.min_score = min_score
83+
self.model.postprocess_op.box_thresh = min_score
84+
if (score_mode := kwargs.get("scoreMode")) is not None:
85+
self.score_mode = score_mode
86+
self.model.postprocess_op.score_mode = score_mode

machine-learning/immich_ml/models/ocr/recognition.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ class TextRecognizer(InferenceModel):
2323
identity = (ModelType.RECOGNITION, ModelTask.OCR)
2424

2525
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
26-
self.min_score = model_kwargs.get("minScore", 0.5)
26+
self.min_score = model_kwargs.get("minScore", 0.9)
2727
self._empty: TextRecognitionOutput = {
2828
"box": np.empty(0, dtype=np.float32),
29-
"boxScore": [],
29+
"boxScore": np.empty(0, dtype=np.float32),
3030
"text": [],
31-
"textScore": [],
31+
"textScore": np.empty(0, dtype=np.float32),
3232
}
3333
super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
3434

@@ -62,24 +62,20 @@ def _load(self) -> ModelSession:
6262
)
6363
return session
6464

65-
def configure(self, **kwargs: Any) -> None:
66-
self.min_score = kwargs.get("minScore", self.min_score)
67-
68-
def _predict(self, _: Image, texts: TextDetectionOutput, **kwargs: Any) -> TextRecognitionOutput:
69-
boxes, resized_img, box_scores = texts["boxes"], texts["resized"], texts["scores"]
65+
def _predict(self, _: Image, texts: TextDetectionOutput) -> TextRecognitionOutput:
66+
boxes, img, box_scores = texts["boxes"], texts["image"], texts["scores"]
7067
if boxes.shape[0] == 0:
7168
return self._empty
72-
rec = self.model(TextRecInput(img=self.get_crop_img_list(resized_img, boxes)))
69+
rec = self.model(TextRecInput(img=self.get_crop_img_list(img, boxes)))
7370
if rec.txts is None:
7471
return self._empty
7572

76-
height, width = resized_img.shape[0:2]
77-
log.info(f"Image shape: width={width}, height={height}")
73+
height, width = img.shape[0:2]
7874
boxes[:, :, 0] /= width
7975
boxes[:, :, 1] /= height
8076

8177
text_scores = np.array(rec.scores)
82-
valid_text_score_idx = text_scores > 0.5
78+
valid_text_score_idx = text_scores > self.min_score
8379
valid_score_idx_list = valid_text_score_idx.tolist()
8480
return {
8581
"box": boxes.reshape(-1, 8)[valid_text_score_idx].reshape(-1),
@@ -115,3 +111,6 @@ def get_crop_img_list(self, img: np.ndarray, boxes: np.ndarray) -> list[np.ndarr
115111
dst_img = np.rot90(dst_img)
116112
imgs.append(dst_img)
117113
return imgs
114+
115+
def configure(self, **kwargs: Any) -> None:
116+
self.min_score = kwargs.get("minScore", self.min_score)

machine-learning/immich_ml/models/ocr/schemas.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77

88

99
class TextDetectionOutput(TypedDict):
10-
resized: npt.NDArray[np.float32]
10+
image: npt.NDArray[np.float32]
1111
boxes: npt.NDArray[np.float32]
1212
scores: npt.NDArray[np.float32]
1313

1414

1515
class TextRecognitionOutput(TypedDict):
1616
box: npt.NDArray[np.float32]
17-
boxScore: Iterable[float]
17+
boxScore: npt.NDArray[np.float32]
1818
text: Iterable[str]
19-
textScore: Iterable[float]
19+
textScore: npt.NDArray[np.float32]
2020

2121

2222
# RapidOCR expects engine_type to be an attribute

machine-learning/immich_ml/sessions/ort.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
class OrtSession:
1717
session: ort.InferenceSession
18+
1819
def __init__(
1920
self,
2021
model_path: Path | str,

server/src/config.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ export const defaults = Object.freeze<SystemConfig>({
254254
enabled: true,
255255
modelName: 'PP-OCRv5_server',
256256
minDetectionScore: 0.5,
257-
minRecognitionScore: 0.5,
258-
maxResolution: 1440,
257+
minRecognitionScore: 0.9,
258+
maxResolution: 736,
259259
},
260260
},
261261
map: {

0 commit comments

Comments
 (0)