Skip to content

Commit 1b9546f

Browse files
[FIX] JSON regex check based on param (#98)
* JSON regex check based on param Signed-off-by: Deepak <[email protected]> * Made extract_json param default value as true Signed-off-by: Deepak <[email protected]> * Version bump Signed-off-by: Deepak <[email protected]> --------- Signed-off-by: Deepak <[email protected]>
1 parent 41525a4 commit 1b9546f

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

src/unstract/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.48.0"
1+
__version__ = "0.48.1"
22

33

44
def get_sdk_version():

src/unstract/sdk/llm.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,24 +69,28 @@ def _initialise(self):
6969
def complete(
7070
self,
7171
prompt: str,
72+
extract_json: bool = True,
7273
process_text: Optional[Callable[[str], str]] = None,
7374
**kwargs: Any,
74-
) -> Optional[dict[str, Any]]:
75+
) -> dict[str, Any]:
7576
"""Generates a completion response for the given prompt.
7677
7778
Args:
7879
prompt (str): The input text prompt for generating the completion.
80+
extract_json (bool, optional): If set to True, the response text is
81+
processed using a regex to extract JSON content from it. If no JSON is
82+
found, the text is returned as it is. Defaults to True.
7983
process_text (Optional[Callable[[str], str]], optional): A callable that
8084
processes the generated text and extracts specific information.
8185
Defaults to None.
8286
**kwargs (Any): Additional arguments passed to the completion function.
8387
8488
Returns:
85-
Optional[dict[str, Any]]: A dictionary containing the result of the
86-
completion and processed output or None if the completion fails.
89+
dict[str, Any]: A dictionary containing the result of the completion
90+
and any processed output.
8791
8892
Raises:
89-
Any: If an error occurs during the completion process, it will be
93+
LLMError: If an error occurs during the completion process, it will be
9094
raised after being processed by `parse_llm_err`.
9195
"""
9296
try:
@@ -100,9 +104,10 @@ def complete(
100104
except Exception as e:
101105
logger.error(f"Error occured inside function 'process_text': {e}")
102106
process_text_output = {}
103-
match = LLM.json_regex.search(response.text)
104-
if match:
105-
response.text = match.group(0)
107+
if extract_json:
108+
match = LLM.json_regex.search(response.text)
109+
if match:
110+
response.text = match.group(0)
106111
return {LLM.RESPONSE: response, **process_text_output}
107112
except Exception as e:
108113
raise parse_llm_err(e) from e
@@ -189,7 +194,7 @@ def get_class_name(self) -> str:
189194
return self._llm_instance.class_name()
190195

191196
def get_model_name(self) -> str:
192-
"""Gets the name of the LLM model
197+
"""Gets the name of the LLM model.
193198
194199
Args:
195200
NA

0 commit comments

Comments
 (0)