Skip to content

Commit dea9bc3

Browse files
Added new default configs for llm whisperer (#78)
* Added default config for add_line_nos and output_json

Signed-off-by: Deepak <[email protected]>

* Write metadata to file if output_file_path is provided

Signed-off-by: Deepak <[email protected]>

* Fixed metadata file path

Signed-off-by: Deepak <[email protected]>

* Added docstring

Signed-off-by: Deepak <[email protected]>

* Resolved review comment

---------

Signed-off-by: Deepak <[email protected]>
Signed-off-by: Deepak K <[email protected]>
1 parent f8a425a commit dea9bc3

File tree

2 files changed

+60
-13
lines changed

2 files changed

+60
-13
lines changed

src/unstract/sdk/adapters/x2text/llm_whisperer/src/constants.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,8 @@ class WhispererConfig:
6161
HORIZONTAL_STRETCH_FACTOR = "horizontal_stretch_factor"
6262
PAGES_TO_EXTRACT = "pages_to_extract"
6363
STORE_METADATA_FOR_HIGHLIGHTING = "store_metadata_for_highlighting"
64+
ADD_LINE_NOS = "add_line_nos"
65+
OUTPUT_JSON = "output_json"
6466
PAGE_SEPARATOR = "page_seperator"
6567

6668

@@ -87,4 +89,6 @@ class WhispererDefaults:
8789
POLL_INTERVAL = int(os.getenv(WhispererEnv.POLL_INTERVAL, 30))
8890
MAX_POLLS = int(os.getenv(WhispererEnv.MAX_POLLS, 30))
8991
PAGES_TO_EXTRACT = ""
92+
ADD_LINE_NOS = True
93+
OUTPUT_JSON = True
9094
PAGE_SEPARATOR = "<<< >>>"

src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import json
12
import logging
23
import os
34
import time
5+
from pathlib import Path
46
from typing import Any, Optional
57

68
import requests
@@ -162,6 +164,8 @@ def _get_whisper_params(self, enable_highlight: bool = False) -> dict[str, Any]:
162164
WhispererConfig.PAGES_TO_EXTRACT,
163165
WhispererDefaults.PAGES_TO_EXTRACT,
164166
),
167+
WhispererConfig.ADD_LINE_NOS: WhispererDefaults.ADD_LINE_NOS,
168+
WhispererConfig.OUTPUT_JSON: WhispererDefaults.OUTPUT_JSON,
165169
WhispererConfig.PAGE_SEPARATOR: self.config.get(
166170
WhispererConfig.PAGE_SEPARATOR,
167171
WhispererDefaults.PAGE_SEPARATOR,
@@ -264,7 +268,10 @@ def _extract_async(self, whisper_hash: str) -> str:
264268
logger.info(f"Extracting async for whisper hash: {whisper_hash}")
265269

266270
headers: dict[str, Any] = self._get_request_headers()
267-
params = {WhisperStatus.WHISPER_HASH: whisper_hash}
271+
params = {
272+
WhisperStatus.WHISPER_HASH: whisper_hash,
273+
WhispererConfig.OUTPUT_JSON: WhispererDefaults.OUTPUT_JSON,
274+
}
268275

269276
# Polls in fixed intervals and checks status
270277
self._check_status_until_ready(
@@ -278,7 +285,7 @@ def _extract_async(self, whisper_hash: str) -> str:
278285
params=params,
279286
)
280287
if retrieve_response.status_code == 200:
281-
return retrieve_response.content.decode("utf-8")
288+
return retrieve_response.json()
282289
else:
283290
raise ExtractorError(
284291
"Error retrieving from LLMWhisperer: "
@@ -310,25 +317,61 @@ def _send_whisper_request(
310317
def _extract_text_from_response(
311318
self, output_file_path: Optional[str], response: requests.Response
312319
) -> str:
313-
314-
output = ""
320+
output_json = {}
315321
if response.status_code == 200:
316-
output = response.content.decode("utf-8")
322+
output_json = response.json()
317323
elif response.status_code == 202:
318324
whisper_hash = response.json().get(WhisperStatus.WHISPER_HASH)
319-
output = self._extract_async(whisper_hash=whisper_hash)
325+
output_json = self._extract_async(whisper_hash=whisper_hash)
320326
else:
321327
raise ExtractorError("Couldn't extract text from file")
328+
if output_file_path:
329+
self._write_output_to_file(
330+
output_json=output_json,
331+
output_file_path=Path(output_file_path),
332+
)
333+
return output_json.get("text", "")
334+
335+
def _write_output_to_file(self, output_json: dict, output_file_path: Path) -> None:
336+
"""Writes the extracted text and metadata to the specified output file
337+
and metadata file.
338+
339+
Args:
340+
output_json (dict): The dictionary containing the extracted data,
341+
with "text" as the key for the main content.
342+
output_file_path (Path): The file path where the extracted text
343+
should be written.
322344
345+
Raises:
346+
ExtractorError: If there is an error while writing the output file.
347+
"""
323348
try:
324-
# Write output to a file
325-
if output_file_path:
326-
with open(output_file_path, "w", encoding="utf-8") as f:
327-
f.write(output)
328-
except OSError as e:
329-
logger.error(f"OS error while writing {output_file_path}: {e} ")
349+
text_output = output_json.get("text", "")
350+
logger.info(f"Writing output to {output_file_path}")
351+
output_file_path.write_text(text_output, encoding="utf-8")
352+
try:
353+
# Define the directory of the output file and metadata paths
354+
output_dir = output_file_path.parent
355+
metadata_dir = output_dir / "metadata"
356+
metadata_file_name = output_file_path.with_suffix(".json").name
357+
metadata_file_path = metadata_dir / metadata_file_name
358+
# Ensure the metadata directory exists
359+
metadata_dir.mkdir(parents=True, exist_ok=True)
360+
# Remove the "text" key from the metadata
361+
metadata = {
362+
key: value for key, value in output_json.items() if key != "text"
363+
}
364+
metadata_json = json.dumps(metadata, ensure_ascii=False, indent=4)
365+
logger.info(f"Writing metadata to {metadata_file_path}")
366+
metadata_file_path.write_text(metadata_json, encoding="utf-8")
367+
except Exception as e:
368+
logger.error(
369+
f"Error while writing metadata to {metadata_file_path}: {e}"
370+
)
371+
372+
except Exception as e:
373+
logger.error(f"Error while writing {output_file_path}: {e}")
330374
raise ExtractorError(str(e))
331-
return output
332375

333376
def process(
334377
self,

0 commit comments

Comments
 (0)