Skip to content

Commit 5a851a0

Browse files
authored
Chore: add support for xml_keep_tags param (#136)
* merge with encoding param
* wrote xml keep tags param test
* update changelog and readme
* bump requirements
* remove spaces in readme curl sample
1 parent e70e00d commit 5a851a0

File tree

7 files changed

+80
-8
lines changed

7 files changed

+80
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## 0.0.30-dev1
22

33
* Add support for `encoding` parameter
4+
* Add support for `xml_keep_tags` parameter
45
* Add env variables for additional parallel mode tweaking
56

67
## 0.0.29

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,19 @@ curl -X 'POST'
107107
| jq -C . | less -R
108108
```
109109

110+
#### XML Tags
111+
112+
When processing XML documents, set the `xml_keep_tags` parameter to `true` to retain the XML tags in the output. If the parameter is not specified, only the text within the tags is extracted.
113+
114+
```
115+
curl -X 'POST' \
116+
'https://api.unstructured.io/general/v0/general' \
117+
-H 'accept: application/json' \
118+
-H 'Content-Type: multipart/form-data' \
119+
-F 'files=@sample-docs/fake-xml.xml' \
120+
-F 'xml_keep_tags=true' \
121+
| jq -C . | less -R
122+
```
110123

111124
## Developer Quick Start
112125

pipeline-notebooks/pipeline-general.ipynb

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,7 @@
725725
" m_coordinates=[],\n",
726726
" m_ocr_languages=[],\n",
727727
" m_encoding=[],\n",
728+
" m_xml_keep_tags=[],\n",
728729
" file_content_type=None,\n",
729730
" response_type=\"application/json\"\n",
730731
"):\n",
@@ -752,6 +753,9 @@
752753
"\n",
753754
" encoding = m_encoding[0] if len(m_encoding) else None\n",
754755
" \n",
756+
" xml_keep_tags_str = (m_xml_keep_tags[0] if len(m_xml_keep_tags) else \"false\").lower()\n",
757+
" xml_keep_tags = xml_keep_tags_str == \"true\"\n",
758+
" \n",
755759
" try:\n",
756760
" if file_content_type == \"application/pdf\" and pdf_parallel_mode_enabled:\n",
757761
" elements = partition_pdf_splits(\n",
@@ -772,6 +776,7 @@
772776
" strategy=strategy,\n",
773777
" ocr_languages=ocr_languages,\n",
774778
" encoding=encoding,\n",
779+
" xml_keep_tags=xml_keep_tags,\n",
775780
" )\n",
776781
" except ValueError as e:\n",
777782
" if 'Invalid file' in e.args[0]:\n",

prepline_general/api/general.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def pipeline_api(
183183
m_coordinates=[],
184184
m_ocr_languages=[],
185185
m_encoding=[],
186+
m_xml_keep_tags=[],
186187
file_content_type=None,
187188
response_type="application/json",
188189
):
@@ -209,6 +210,9 @@ def pipeline_api(
209210

210211
encoding = m_encoding[0] if len(m_encoding) else None
211212

213+
xml_keep_tags_str = (m_xml_keep_tags[0] if len(m_xml_keep_tags) else "false").lower()
214+
xml_keep_tags = xml_keep_tags_str == "true"
215+
212216
try:
213217
if file_content_type == "application/pdf" and pdf_parallel_mode_enabled:
214218
elements = partition_pdf_splits(
@@ -229,6 +233,7 @@ def pipeline_api(
229233
strategy=strategy,
230234
ocr_languages=ocr_languages,
231235
encoding=encoding,
236+
xml_keep_tags=xml_keep_tags,
232237
)
233238
except ValueError as e:
234239
if "Invalid file" in e.args[0]:
@@ -372,6 +377,7 @@ def pipeline_1(
372377
coordinates: List[str] = Form(default=[]),
373378
ocr_languages: List[str] = Form(default=[]),
374379
encoding: List[str] = Form(default=[]),
380+
xml_keep_tags: List[str] = Form(default=[]),
375381
):
376382
if files:
377383
for file_index in range(len(files)):
@@ -410,6 +416,7 @@ def response_generator(is_multipart):
410416
m_coordinates=coordinates,
411417
m_ocr_languages=ocr_languages,
412418
m_encoding=encoding,
419+
m_xml_keep_tags=xml_keep_tags,
413420
response_type=media_type,
414421
filename=file.filename,
415422
file_content_type=file_content_type,

requirements/base.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ anyio==3.7.0
1111
# httpcore
1212
# starlette
1313
# watchfiles
14-
argilla==1.11.0
14+
argilla==1.12.0
1515
# via unstructured
1616
attrs==23.1.0
1717
# via jsonschema
@@ -115,7 +115,7 @@ jinja2==3.1.2
115115
# nbconvert
116116
# torch
117117
# unstructured-api-tools
118-
joblib==1.3.0
118+
joblib==1.3.1
119119
# via nltk
120120
jsonschema==4.17.3
121121
# via nbformat
@@ -369,7 +369,7 @@ traitlets==5.9.0
369369
# nbformat
370370
transformers==4.30.2
371371
# via unstructured-inference
372-
typer==0.9.0
372+
typer==0.7.0
373373
# via argilla
374374
types-requests==2.31.0.1
375375
# via unstructured-api-tools
@@ -387,7 +387,6 @@ typing-extensions==4.7.0
387387
# rich
388388
# starlette
389389
# torch
390-
# typer
391390
unstructured[local-inference]==0.7.10
392391
# via -r requirements/base.in
393392
unstructured-api-tools==0.10.7

requirements/test.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ appnope==0.1.3
1919
# via
2020
# ipykernel
2121
# ipython
22-
argilla==1.11.0
22+
argilla==1.12.0
2323
# via
2424
# -r requirements/base.txt
2525
# unstructured
@@ -266,7 +266,7 @@ jinja2==3.1.2
266266
# notebook
267267
# torch
268268
# unstructured-api-tools
269-
joblib==1.3.0
269+
joblib==1.3.1
270270
# via
271271
# -r requirements/base.txt
272272
# nltk
@@ -812,7 +812,7 @@ transformers==4.30.2
812812
# via
813813
# -r requirements/base.txt
814814
# unstructured-inference
815-
typer==0.9.0
815+
typer==0.7.0
816816
# via
817817
# -r requirements/base.txt
818818
# argilla
@@ -841,7 +841,6 @@ typing-extensions==4.7.0
841841
# rich
842842
# starlette
843843
# torch
844-
# typer
845844
unstructured[local-inference]==0.7.10
846845
# via -r requirements/base.txt
847846
unstructured-api-tools==0.10.7

test_general/api/test_app.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
import io
66
import pytest
7+
import re
78
import requests
89
import ast
910
import pandas as pd
@@ -236,6 +237,53 @@ def test_api_with_different_encodings():
236237
assert "invalid start byte" in str(excinfo.value)
237238

238239

240+
def test_xml_keep_tags_param():
241+
"""
242+
Verify that responses do not include xml tags unless requested
243+
"""
244+
client = TestClient(app)
245+
test_file = Path("sample-docs") / "fake-xml.xml"
246+
response = client.post(
247+
MAIN_API_ROUTE,
248+
files=[("files", (str(test_file), open(test_file, "rb")))],
249+
data={"strategy": "hi_res"},
250+
)
251+
assert response.status_code == 200
252+
response_without_xml_tags = response.json()
253+
254+
response = client.post(
255+
MAIN_API_ROUTE,
256+
files=[("files", (str(test_file), open(test_file, "rb")))],
257+
data={"xml_keep_tags": "true", "strategy": "hi_res"},
258+
)
259+
assert response.status_code == 200
260+
response_with_xml_tags = response.json()[3:] # skip the initial encoding tag(s)
261+
262+
# The responses should have the same content except for the xml tags
263+
response_with_xml_tags_index, response_without_xml_tags_index = 0, 0
264+
while response_without_xml_tags_index < len(response_without_xml_tags):
265+
xml_tagged_line = response_with_xml_tags[response_with_xml_tags_index]["text"]
266+
assert xml_tagged_line.startswith("<")
267+
assert xml_tagged_line.endswith(">")
268+
269+
# if there is content on this line, ensure it matches the content on the non tagged line
270+
xml_tagged_line_content = xml_tagged_line.split(">", 1)[1] # remove opening tag
271+
if not xml_tagged_line_content:
272+
response_with_xml_tags_index += 1
273+
274+
else:
275+
xml_tagged_line_content = xml_tagged_line_content.split("<", 1)[0] # remove closing tag
276+
277+
xml_untagged_line = response_without_xml_tags[response_without_xml_tags_index]["text"]
278+
xml_tagged_line_content_parsed = re.sub(
279+
"&amp;", "&", xml_tagged_line_content
280+
) # xml_keep_tags does not currently parse the inner content
281+
assert xml_tagged_line_content_parsed == xml_untagged_line
282+
283+
response_with_xml_tags_index += 1
284+
response_without_xml_tags_index += 1
285+
286+
239287
@pytest.mark.parametrize(
240288
"example_filename",
241289
[

0 commit comments

Comments
 (0)