Skip to content

Commit d5263fb

Browse files
refactor: fix ocr_languages parameter type (#375)
This PR adds support for both `list[str]` and `str` input formats for the `ocr_languages` parameter (e.g. `["eng", "deu"]` or `"eng+deu"`). Testing: CI should pass.
1 parent f7d037f commit d5263fb

File tree

7 files changed

+28
-9
lines changed

7 files changed

+28
-9
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
## 0.0.65-dev0
2+
* Add support for both `list[str]` and `str` input formats for the `ocr_languages` parameter
3+
14
## 0.0.64
25
* Bump Pydantic to 2.5.x and remove it from explicit dependencies list (will be managed by fastapi)
36
* Introduce Form params description in the code, which will form openapi and swagger documentation

prepline_general/api/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
app = FastAPI(
1313
title="Unstructured Pipeline API",
1414
summary="Partition documents with the Unstructured library",
15-
version="0.0.64",
15+
version="0.0.65",
1616
docs_url="/general/docs",
1717
openapi_url="/general/openapi.json",
1818
servers=[

prepline_general/api/general.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ def return_content_type(filename: str):
688688

689689

690690
@router.get("/general/v0/general", include_in_schema=False)
691-
@router.get("/general/v0.0.64/general", include_in_schema=False)
691+
@router.get("/general/v0.0.65/general", include_in_schema=False)
692692
async def handle_invalid_get_request():
693693
raise HTTPException(
694694
status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Only POST requests are supported."
@@ -703,7 +703,7 @@ async def handle_invalid_get_request():
703703
description="Description",
704704
operation_id="partition_parameters",
705705
)
706-
@router.post("/general/v0.0.64/general", include_in_schema=False)
706+
@router.post("/general/v0.0.65/general", include_in_schema=False)
707707
def general_partition(
708708
request: Request,
709709
# cannot use annotated type here because of a bug described here:

prepline_general/api/models/form_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def as_form(
6161
description="The languages present in the document, for use in partitioning and/or OCR",
6262
example="[eng]",
6363
),
64-
# BeforeValidator(SmartValueParser[List[str]]().value_or_first_element),
64+
BeforeValidator(SmartValueParser[List[str]]().value_or_first_element),
6565
] = [],
6666
skip_infer_table_types: Annotated[
6767
List[str],

prepline_general/api/utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,15 @@ def _return_cast_first_element(values: list[E], origin_class: type) -> E | None:
3939

4040

4141
def is_convertible_to_list(s: str) -> Tuple[bool, Union[List, str]]:
42-
"""Determines if a given string is convertible to a list by attempting to parse it as JSON."""
42+
"""
43+
Determines if a given string is convertible to a list.
44+
45+
This function first tries to parse the string as JSON. If the parsed JSON is a list, it returns
46+
True along with the list. If parsing as JSON fails, it then checks if the string can be split
47+
into a list using predefined delimiters ("," or "+"). If so, it returns True and the resulting list.
48+
If neither condition is met, it returns False and a message indicating the string cannot
49+
be converted to a list.
50+
"""
4351

4452
try:
4553
result = json.loads(s)
@@ -48,7 +56,14 @@ def is_convertible_to_list(s: str) -> Tuple[bool, Union[List, str]]:
4856
else:
4957
return False, "Input is valid JSON but not a list." # Valid JSON but not a list
5058
except json.JSONDecodeError:
51-
return False, "Input is not valid JSON." # Invalid JSON
59+
pass # proceed to check using delimiters if JSON parsing fails
60+
61+
delimiters = ["+", ","]
62+
for delimiter in delimiters:
63+
if delimiter in delimiters:
64+
return True, s.split(delimiter)
65+
66+
return False, "Input is not valid JSON." # Invalid JSON
5267

5368

5469
class SmartValueParser(Generic[T]):

preprocessing-pipeline-family.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
name: general
2-
version: 0.0.64
2+
version: 0.0.65

test_general/api/test_app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def test_metadata_fields_removed():
140140
assert "detection_class_prob" not in response_without_coords[i]["metadata"]
141141

142142

143-
def test_ocr_languages_param(): # will eventually be depricated
143+
@pytest.mark.parametrize("ocr_languages", [["eng", "kor"], ["eng+kor"]])
144+
def test_ocr_languages_param(ocr_languages): # will eventually be deprecated
144145
"""
145146
Verify that we get the corresponding languages from the response with ocr_languages
146147
"""
@@ -149,7 +150,7 @@ def test_ocr_languages_param(): # will eventually be depricated
149150
response = client.post(
150151
MAIN_API_ROUTE,
151152
files=[("files", (str(test_file), open(test_file, "rb")))],
152-
data={"strategy": "ocr_only", "ocr_languages": ["eng", "kor"]},
153+
data={"strategy": "ocr_only", "ocr_languages": ocr_languages},
153154
)
154155

155156
assert response.status_code == 200

0 commit comments

Comments
 (0)