@@ -69,24 +69,28 @@ def _initialise(self):
6969 def complete (
7070 self ,
7171 prompt : str ,
72+ extract_json : bool = True ,
7273 process_text : Optional [Callable [[str ], str ]] = None ,
7374 ** kwargs : Any ,
74- ) -> Optional [ dict [str , Any ] ]:
75+ ) -> dict [str , Any ]:
7576 """Generates a completion response for the given prompt.
7677
7778 Args:
7879 prompt (str): The input text prompt for generating the completion.
80+ extract_json (bool, optional): If set to True, the response text is
81+ processed using a regex to extract JSON content from it. If no JSON is
82+ found, the text is returned as it is. Defaults to True.
7983 process_text (Optional[Callable[[str], str]], optional): A callable that
8084 processes the generated text and extracts specific information.
8185 Defaults to None.
8286 **kwargs (Any): Additional arguments passed to the completion function.
8387
8488 Returns:
85- Optional[ dict[str, Any]] : A dictionary containing the result of the
86- completion and processed output or None if the completion fails .
89+ dict[str, Any]: A dictionary containing the result of the completion
90+ and any processed output.
8791
8892 Raises:
89- Any : If an error occurs during the completion process, it will be
93+ LLMError : If an error occurs during the completion process, it will be
9094 raised after being processed by `parse_llm_err`.
9195 """
9296 try :
@@ -100,9 +104,10 @@ def complete(
100104 except Exception as e :
101105 logger .error (f"Error occured inside function 'process_text': { e } " )
102106 process_text_output = {}
103- match = LLM .json_regex .search (response .text )
104- if match :
105- response .text = match .group (0 )
107+ if extract_json :
108+ match = LLM .json_regex .search (response .text )
109+ if match :
110+ response .text = match .group (0 )
106111 return {LLM .RESPONSE : response , ** process_text_output }
107112 except Exception as e :
108113 raise parse_llm_err (e ) from e
@@ -189,7 +194,7 @@ def get_class_name(self) -> str:
189194 return self ._llm_instance .class_name ()
190195
191196 def get_model_name (self ) -> str :
192- """Gets the name of the LLM model
197+ """Gets the name of the LLM model.
193198
194199 Args:
195200 NA
0 commit comments