diff --git a/extract_thinker/concatenation_handler.py b/extract_thinker/concatenation_handler.py
index f8f51ae..4d74fb8 100644
--- a/extract_thinker/concatenation_handler.py
+++ b/extract_thinker/concatenation_handler.py
@@ -27,6 +27,7 @@ def _is_valid_json_continuation(self, response: str) -> bool:
 
         return has_json_markers
 
+
     def handle(self, content: Any, response_model: type[BaseModel], vision: bool = False, extra_content: Optional[str] = None) -> Any:
         self.json_parts = []
         messages = self._build_messages(content, vision, response_model)
@@ -38,27 +39,28 @@ def handle(self, content: Any, response_model: type[BaseModel], vision: bool = F
         max_retries = 3
         while True:
             try:
-                response = self.llm.raw_completion(messages)
-
-                # Validate if it's a proper JSON continuation
-                if not self._is_valid_json_continuation(response):
-                    retry_count += 1
-                    if retry_count >= max_retries:
-                        raise ValueError("Maximum retries reached with invalid JSON continuations")
-                    continue
-
+                response_obj = self.llm.raw_completion_complete(messages)
+                response = response_obj.message.content
+                finish_reason = response_obj.finish_reason
+
                 self.json_parts.append(response)
-
-                # Try to process and validate the JSON
-                result = self._process_json_parts(response_model)
-                return result
-
+
+                if self._is_finish(response_obj):
+                    result = self._process_json_parts(response_model)
+                    return result
+
+                retry_count += 1
+                if retry_count >= max_retries:
+                    raise ValueError("Maximum retries reached with incomplete response")
+                messages = self._build_continuation_messages(messages, response)
+
             except ValueError as e:
                 if retry_count >= max_retries:
                     raise ValueError(f"Maximum retries reached: {str(e)}")
                 retry_count += 1
                 messages = self._build_continuation_messages(messages, response)
-
+
+
     def _process_json_parts(self, response_model: type[BaseModel]) -> Any:
         """Process collected JSON parts into a complete response."""
         if not self.json_parts:
diff --git a/extract_thinker/llm.py b/extract_thinker/llm.py
index 95c9971..828916a 100644
--- a/extract_thinker/llm.py
+++ b/extract_thinker/llm.py
@@ -322,6 +322,52 @@ def raw_completion(self, messages: List[Dict[str, str]]) -> str:
             raw_response = litellm.completion(**params)
         return raw_response.choices[0].message.content
+
+    def raw_completion_complete(self, messages: List[Dict[str, str]]) -> Any:
+        """Make a raw completion request and return the full first choice (message plus finish_reason)."""
+        if self.backend == LLMEngine.PYDANTIC_AI:
+            # Combine messages into a single prompt
+            combined_prompt = " ".join([m["content"] for m in messages])
+            try:
+                result = asyncio.run(
+                    self.agent.run(
+                        combined_prompt,
+                        result_type=str
+                    )
+                )
+                return result.data  # note: plain text, not a choice object
+            except Exception as e:
+                raise ValueError(f"Failed to extract from source: {str(e)}")
+
+        max_tokens = self.DEFAULT_OUTPUT_TOKENS
+        if self.token_limit is not None:
+            max_tokens = self.token_limit
+        elif self.is_thinking:
+            max_tokens = self.thinking_token_limit
+
+        params = {
+            "model": self.model,
+            "messages": messages,
+            "max_completion_tokens": max_tokens,
+        }
+
+        if self.is_thinking:
+            if litellm.supports_reasoning(self.model):
+                # Add thinking parameter for supported models
+                thinking_param = {
+                    "type": "enabled",
+                    "budget_tokens": self.thinking_budget
+                }
+                params["thinking"] = thinking_param
+            else:
+                print(f"Warning: Model {self.model} doesn't support thinking parameter, proceeding without it.")
+
+        if self.router:
+            raw_response = self.router.completion(**params)
+        else:
+            raw_response = litellm.completion(**params)
+
+        return raw_response.choices[0]  # full choice so callers can check finish_reason
 
     def set_timeout(self, timeout_ms: int) -> None:
         """Set the timeout value for LLM requests in milliseconds."""
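
For reviewers, here is a minimal, self-contained sketch of the continuation flow this change introduces: call raw_completion_complete(), keep the partial text, and only stop concatenating once the returned choice reports a finish reason other than a token-limit cut-off. The _is_finish and _build_continuation_messages helpers are not part of this diff, so the "length" check and the continuation prompt below are assumptions, and collect_json / FakeLLM are hypothetical names used only for illustration.

from types import SimpleNamespace
from typing import Any, Dict, List


def _is_finish(choice: Any) -> bool:
    # Assumption: finish_reason == "length" means the model hit max_completion_tokens
    # and the JSON is likely truncated; any other reason counts as complete.
    return choice.finish_reason != "length"


def collect_json(llm: Any, messages: List[Dict[str, str]], max_retries: int = 3) -> str:
    """Accumulate partial JSON chunks until the model finishes for a real reason."""
    parts: List[str] = []
    for _ in range(max_retries + 1):
        choice = llm.raw_completion_complete(messages)  # full choice, not just text
        parts.append(choice.message.content)
        if _is_finish(choice):
            return "".join(parts)
        # Feed the partial output back and ask the model to pick up where it stopped
        # (the role _build_continuation_messages plays in the handler).
        messages = messages + [
            {"role": "assistant", "content": choice.message.content},
            {"role": "user", "content": "Continue the JSON exactly where you stopped."},
        ]
    raise ValueError("Maximum retries reached with incomplete response")


if __name__ == "__main__":
    # Fake LLM that returns a truncated chunk first, then the remainder.
    chunks = iter([('{"name": "In', "length"), ('voice"}', "stop")])

    class FakeLLM:
        def raw_completion_complete(self, messages: List[Dict[str, str]]) -> Any:
            text, reason = next(chunks)
            return SimpleNamespace(message=SimpleNamespace(content=text), finish_reason=reason)

    print(collect_json(FakeLLM(), [{"role": "user", "content": "Extract the document as JSON."}]))

Returning the whole choice (rather than just the text, as raw_completion does) is what lets the handler distinguish "model was cut off" from "model produced bad JSON", which is why the retry loop now keys off the finish reason instead of re-validating each chunk as a JSON continuation.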