diff --git a/patchwork/steps/SimplifiedLLM/SimplifiedLLM.py b/patchwork/steps/SimplifiedLLM/SimplifiedLLM.py
index f37dc9035..167416a80 100644
--- a/patchwork/steps/SimplifiedLLM/SimplifiedLLM.py
+++ b/patchwork/steps/SimplifiedLLM/SimplifiedLLM.py
@@ -16,6 +16,12 @@ class SimplifiedLLM(Step):
+    # Models that don't support native JSON mode
+    JSON_MODE_UNSUPPORTED_MODELS = {
+        "gemini-2.0-flash-thinking-exp",
+        # Add other models here as needed
+    }
+
     def __init__(self, inputs):
         super().__init__(inputs)
         missing_keys = SimplifiedLLMInputs.__required_keys__.difference(set(inputs.keys()))
@@ -28,6 +34,7 @@ def __init__(self, inputs):
         self.is_json_mode = inputs.get("json", False)
         self.json_example = inputs.get("json_example")
         self.inputs = inputs
+        self.is_json_mode_unsupported = inputs.get("model") in self.JSON_MODE_UNSUPPORTED_MODELS
 
     def __record_status_or_raise(self, retry_data: RetryData, step: Step):
         if retry_data.retry_count == retry_data.retry_limit or step.status != StepStatus.FAILED:
@@ -49,6 +56,17 @@ def __json_loads(json_str: str) -> dict:
             logger.debug(f"Json to decode: \n{json_str}\nError: \n{e}")
             raise e
 
+    @staticmethod
+    def __extract_json_from_text(text: str) -> str:
+        try:
+            start = text.find("{")
+            end = text.rfind("}")
+            if start != -1 and end != -1:
+                return text[start : end + 1]
+            return text
+        except Exception:
+            return text
+
     def __retry_unit(self, prepare_prompt_outputs, call_llm_inputs, retry_data: RetryData):
         call_llm = CallLLM(call_llm_inputs)
         call_llm_outputs = call_llm.run()
@@ -56,8 +74,13 @@ def __retry_unit(self, prepare_prompt_outputs, call_llm_inputs, retry_data: Retr
 
         if self.is_json_mode:
             json_responses = []
+
             for response in call_llm_outputs.get("openai_responses"):
                 try:
+                    # For models that don't support JSON mode, extract JSON from the text response first
+                    if self.is_json_mode_unsupported:
+                        response = self.__extract_json_from_text(response)
+
                     json_response = self.__json_loads(response)
                     json_responses.append(json_response)
                 except json.JSONDecodeError as e:
@@ -91,6 +114,14 @@ def run(self) -> dict:
         prompts = [dict(role="user", content=self.user)]
         if self.system:
             prompts.insert(0, dict(role="system", content=self.system))
+
+        # Special handling for models that don't support JSON mode
+        if self.is_json_mode_unsupported and self.is_json_mode and self.json_example:
+            # Append JSON example to user message
+            prompts[-1][
+                "content"
+            ] += f"\nPlease format your response as a JSON object like this example:\n{json.dumps(self.json_example, indent=2)}"
+
         prepare_prompt_inputs = dict(
             prompt_template=prompts,
             prompt_values=self.prompt_values,
@@ -100,9 +131,14 @@ def run(self) -> dict:
         self.set_status(prepare_prompt.status, prepare_prompt.status_message)
 
         model_keys = [key for key in self.inputs.keys() if key.startswith("model_")]
-        response_format = dict(type="json_object" if self.is_json_mode else "text")
-        if self.json_example is not None:
-            response_format = example_json_to_schema(self.json_example)
+
+        # Set response format based on model and mode
+        response_format = None
+        if not self.is_json_mode_unsupported:
+            response_format = dict(type="json_object" if self.is_json_mode else "text")
+            if self.json_example is not None:
+                response_format = example_json_to_schema(self.json_example)
+
         call_llm_inputs = {
             "prompts": prepare_prompt_outputs.get("prompts"),
             **{
diff --git a/pyproject.toml b/pyproject.toml
index c46b86bf1..0ea3dc951 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "patchwork-cli"
-version = "0.0.94"
+version = "0.0.95"
 description = ""
 authors = ["patched.codes"]
 license = "AGPL"