Skip to content

Commit 28434ca

Browse files
Added support to pass functions to index and llm complete functions (#79)
* Added support to pass functions to index and llm complete functions Signed-off-by: Deepak <[email protected]> * Added docstring Signed-off-by: Deepak <[email protected]> --------- Signed-off-by: Deepak <[email protected]> Signed-off-by: Deepak K <[email protected]> Co-authored-by: Gayathri <[email protected]>
1 parent dea9bc3 commit 28434ca

File tree

2 files changed

+42
-26
lines changed

2 files changed

+42
-26
lines changed

src/unstract/sdk/index.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
2-
from typing import Any, Optional
2+
import logging
3+
from typing import Any, Callable, Optional
34

45
from llama_index.core import Document
56
from llama_index.core.node_parser import SimpleNodeParser
@@ -25,6 +26,8 @@
2526
from unstract.sdk.vector_db import VectorDB
2627
from unstract.sdk.x2txt import X2Text
2728

29+
logger = logging.getLogger(__name__)
30+
2831

2932
class Constants:
3033
TOP_K = 5
@@ -101,27 +104,6 @@ def query_index(
101104
finally:
102105
vector_db.close()
103106

104-
def _cleanup_text(self, full_text):
105-
# Remove text which is not required
106-
full_text_lines = full_text.split("\n")
107-
new_context_lines = []
108-
empty_line_count = 0
109-
for line in full_text_lines:
110-
if line.strip() == "":
111-
empty_line_count += 1
112-
else:
113-
if empty_line_count >= 3:
114-
empty_line_count = 3
115-
for i in range(empty_line_count):
116-
new_context_lines.append("")
117-
empty_line_count = 0
118-
new_context_lines.append(line.rstrip())
119-
self.tool.stream_log(
120-
f"Old context length: {len(full_text_lines)}, "
121-
f"New context length: {len(new_context_lines)}"
122-
)
123-
return "\n".join(new_context_lines)
124-
125107
def index(
126108
self,
127109
tool_id: str,
@@ -136,6 +118,7 @@ def index(
136118
output_file_path: Optional[str] = None,
137119
enable_highlight: bool = False,
138120
usage_kwargs: dict[Any, Any] = {},
121+
process_text: Optional[Callable[[str], str]] = None,
139122
) -> str:
140123
"""Indexes an individual file using the passed arguments.
141124
@@ -276,10 +259,17 @@ def index(
276259
except AdapterError as e:
277260
# Wrapping AdapterErrors with SdkError
278261
raise IndexingError(str(e)) from e
262+
if process_text:
263+
try:
264+
result = process_text(extracted_text)
265+
if isinstance(result, str):
266+
extracted_text = result
267+
except Exception as e:
268+
logger.error(f"Error occurred inside function 'process_text': {e}")
279269
full_text.append(
280270
{
281271
"section": "full",
282-
"text_contents": self._cleanup_text(extracted_text),
272+
"text_contents": extracted_text,
283273
}
284274
)
285275

src/unstract/sdk/llm.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import re
3-
from typing import Any, Optional
3+
from typing import Any, Callable, Optional
44

55
from llama_index.core.base.llms.types import CompletionResponseGen
66
from llama_index.core.llms import LLM as LlamaIndexLLM
@@ -69,15 +69,41 @@ def _initialise(self):
6969
def complete(
7070
self,
7171
prompt: str,
72-
retries: int = 3,
72+
process_text: Optional[Callable[[str], str]] = None,
7373
**kwargs: Any,
7474
) -> Optional[dict[str, Any]]:
75+
"""Generates a completion response for the given prompt.
76+
77+
Args:
78+
prompt (str): The input text prompt for generating the completion.
79+
process_text (Optional[Callable[[str], str]], optional): A callable that
80+
processes the generated text and extracts specific information.
81+
Defaults to None.
82+
**kwargs (Any): Additional arguments passed to the completion function.
83+
84+
Returns:
85+
Optional[dict[str, Any]]: A dictionary containing the result of the
86+
completion and processed output or None if the completion fails.
87+
88+
Raises:
89+
Any: If an error occurs during the completion process, it will be
90+
raised after being processed by `parse_llm_err`.
91+
"""
7592
try:
7693
response: CompletionResponse = self._llm_instance.complete(prompt, **kwargs)
94+
process_text_output = {}
95+
if process_text:
96+
try:
97+
process_text_output = process_text(response, LLM.json_regex)
98+
if not isinstance(process_text_output, dict):
99+
process_text_output = {}
100+
except Exception as e:
101+
logger.error(f"Error occurred inside function 'process_text': {e}")
102+
process_text_output = {}
77103
match = LLM.json_regex.search(response.text)
78104
if match:
79105
response.text = match.group(0)
80-
return {LLM.RESPONSE: response}
106+
return {LLM.RESPONSE: response, **process_text_output}
81107
except Exception as e:
82108
raise parse_llm_err(e) from e
83109

0 commit comments

Comments
 (0)