
Commit a4cfca0

Add support for json_mode unsupported models (#1239)
* Add support for json_mode unsupported models
* w -> __extract_json_from_text
* Update to version 0.0.95
1 parent 031849e commit a4cfca0
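
In short: when the configured model is listed in JSON_MODE_UNSUPPORTED_MODELS (currently gemini-2.0-flash-thinking-exp), SimplifiedLLM no longer sends a response_format; instead it appends the json_example to the user prompt as a formatting instruction and, when parsing the reply, trims the raw text to its outermost braces before json.loads. A minimal standalone sketch of that trimming fallback, mirroring the helper in the diff below (the sample response string is invented for illustration):

import json

# Same logic as the __extract_json_from_text helper added below:
# keep everything from the first "{" to the last "}".
def extract_json_from_text(text: str) -> str:
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1:
        return text[start : end + 1]
    return text

# Invented example of a reply from a model without native JSON mode,
# where the JSON object is wrapped in conversational prose.
raw = 'Sure, here is the summary you asked for: {"summary": "ok", "score": 3} Hope that helps!'
print(json.loads(extract_json_from_text(raw)))  # -> {'summary': 'ok', 'score': 3}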

File tree

2 files changed: +40 -4 lines changed

patchwork/steps/SimplifiedLLM/SimplifiedLLM.py
pyproject.toml

patchwork/steps/SimplifiedLLM/SimplifiedLLM.py

Lines changed: 39 additions & 3 deletions
@@ -16,6 +16,12 @@


 class SimplifiedLLM(Step):
+    # Models that don't support native JSON mode
+    JSON_MODE_UNSUPPORTED_MODELS = {
+        "gemini-2.0-flash-thinking-exp",
+        # Add other models here as needed
+    }
+
     def __init__(self, inputs):
         super().__init__(inputs)
         missing_keys = SimplifiedLLMInputs.__required_keys__.difference(set(inputs.keys()))
@@ -28,6 +34,7 @@ def __init__(self, inputs):
         self.is_json_mode = inputs.get("json", False)
         self.json_example = inputs.get("json_example")
         self.inputs = inputs
+        self.is_json_mode_unsupported = inputs.get("model") in self.JSON_MODE_UNSUPPORTED_MODELS

     def __record_status_or_raise(self, retry_data: RetryData, step: Step):
         if retry_data.retry_count == retry_data.retry_limit or step.status != StepStatus.FAILED:
@@ -49,15 +56,31 @@ def __json_loads(json_str: str) -> dict:
             logger.debug(f"Json to decode: \n{json_str}\nError: \n{e}")
             raise e

+    @staticmethod
+    def __extract_json_from_text(text: str) -> str:
+        try:
+            start = text.find("{")
+            end = text.rfind("}")
+            if start != -1 and end != -1:
+                return text[start : end + 1]
+            return text
+        except Exception:
+            return text
+
     def __retry_unit(self, prepare_prompt_outputs, call_llm_inputs, retry_data: RetryData):
         call_llm = CallLLM(call_llm_inputs)
         call_llm_outputs = call_llm.run()
         self.__record_status_or_raise(retry_data, call_llm)

         if self.is_json_mode:
             json_responses = []
+
             for response in call_llm_outputs.get("openai_responses"):
                 try:
+                    # For models that don't support JSON mode, extract JSON from the text response first
+                    if self.is_json_mode_unsupported:
+                        response = self.__extract_json_from_text(response)
+
                     json_response = self.__json_loads(response)
                     json_responses.append(json_response)
                 except json.JSONDecodeError as e:
@@ -91,6 +114,14 @@ def run(self) -> dict:
         prompts = [dict(role="user", content=self.user)]
         if self.system:
             prompts.insert(0, dict(role="system", content=self.system))
+
+        # Special handling for models that don't support JSON mode
+        if self.is_json_mode_unsupported and self.is_json_mode and self.json_example:
+            # Append JSON example to user message
+            prompts[-1][
+                "content"
+            ] += f"\nPlease format your response as a JSON object like this example:\n{json.dumps(self.json_example, indent=2)}"
+
         prepare_prompt_inputs = dict(
             prompt_template=prompts,
             prompt_values=self.prompt_values,
@@ -100,9 +131,14 @@ def run(self) -> dict:
         self.set_status(prepare_prompt.status, prepare_prompt.status_message)

         model_keys = [key for key in self.inputs.keys() if key.startswith("model_")]
-        response_format = dict(type="json_object" if self.is_json_mode else "text")
-        if self.json_example is not None:
-            response_format = example_json_to_schema(self.json_example)
+
+        # Set response format based on model and mode
+        response_format = None
+        if not self.is_json_mode_unsupported:
+            response_format = dict(type="json_object" if self.is_json_mode else "text")
+            if self.json_example is not None:
+                response_format = example_json_to_schema(self.json_example)
+
         call_llm_inputs = {
             "prompts": prepare_prompt_outputs.get("prompts"),
             **{

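For a model in JSON_MODE_UNSUPPORTED_MODELS, the run() change above rewrites the last (user) prompt in place rather than relying on response_format. A small sketch of just that step, using made-up stand-ins for self.user and self.json_example:

import json

# Stand-ins for the step's user prompt and json_example input (invented values).
prompts = [dict(role="user", content="Summarise this patch.")]
json_example = {"summary": "", "score": 0}

# Same augmentation the commit applies when the model lacks native JSON mode.
prompts[-1][
    "content"
] += f"\nPlease format your response as a JSON object like this example:\n{json.dumps(json_example, indent=2)}"

print(prompts[-1]["content"])
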
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "patchwork-cli"
-version = "0.0.94"
+version = "0.0.95"
 description = ""
 authors = ["patched.codes"]
 license = "AGPL"

0 commit comments

Comments
 (0)