@@ -16,6 +16,12 @@
 
 
 class SimplifiedLLM(Step):
+    # Models that don't support native JSON mode
+    JSON_MODE_UNSUPPORTED_MODELS = {
+        "gemini-2.0-flash-thinking-exp",
+        # Add other models here as needed
+    }
+
     def __init__(self, inputs):
         super().__init__(inputs)
         missing_keys = SimplifiedLLMInputs.__required_keys__.difference(set(inputs.keys()))
@@ -28,6 +34,7 @@ def __init__(self, inputs):
         self.is_json_mode = inputs.get("json", False)
         self.json_example = inputs.get("json_example")
         self.inputs = inputs
+        self.is_json_mode_unsupported = inputs.get("model") in self.JSON_MODE_UNSUPPORTED_MODELS
 
     def __record_status_or_raise(self, retry_data: RetryData, step: Step):
         if retry_data.retry_count == retry_data.retry_limit or step.status != StepStatus.FAILED:
@@ -49,15 +56,31 @@ def __json_loads(json_str: str) -> dict:
             logger.debug(f"Json to decode:\n{json_str}\nError:\n{e}")
             raise e
 
+    @staticmethod
+    def __extract_json_from_text(text: str) -> str:
+        try:
+            start = text.find("{")
+            end = text.rfind("}")
+            if start != -1 and end != -1:
+                return text[start : end + 1]
+            return text
+        except Exception:
+            return text
+
     def __retry_unit(self, prepare_prompt_outputs, call_llm_inputs, retry_data: RetryData):
         call_llm = CallLLM(call_llm_inputs)
         call_llm_outputs = call_llm.run()
         self.__record_status_or_raise(retry_data, call_llm)
 
         if self.is_json_mode:
             json_responses = []
+
             for response in call_llm_outputs.get("openai_responses"):
                 try:
+                    # For models that don't support JSON mode, extract JSON from the text response first
+                    if self.is_json_mode_unsupported:
+                        response = self.__extract_json_from_text(response)
+
                     json_response = self.__json_loads(response)
                     json_responses.append(json_response)
                 except json.JSONDecodeError as e:
@@ -91,6 +114,14 @@ def run(self) -> dict:
         prompts = [dict(role="user", content=self.user)]
         if self.system:
             prompts.insert(0, dict(role="system", content=self.system))
+
+        # Special handling for models that don't support JSON mode
+        if self.is_json_mode_unsupported and self.is_json_mode and self.json_example:
+            # Append JSON example to user message
+            prompts[-1][
+                "content"
+            ] += f"\nPlease format your response as a JSON object like this example:\n{json.dumps(self.json_example, indent=2)}"
+
         prepare_prompt_inputs = dict(
             prompt_template=prompts,
             prompt_values=self.prompt_values,
@@ -100,9 +131,14 @@ def run(self) -> dict:
         self.set_status(prepare_prompt.status, prepare_prompt.status_message)
 
         model_keys = [key for key in self.inputs.keys() if key.startswith("model_")]
-        response_format = dict(type="json_object" if self.is_json_mode else "text")
-        if self.json_example is not None:
-            response_format = example_json_to_schema(self.json_example)
+
+        # Set response format based on model and mode
+        response_format = None
+        if not self.is_json_mode_unsupported:
+            response_format = dict(type="json_object" if self.is_json_mode else "text")
+            if self.json_example is not None:
+                response_format = example_json_to_schema(self.json_example)
+
         call_llm_inputs = {
             "prompts": prepare_prompt_outputs.get("prompts"),
             **{
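
For reference, a minimal standalone sketch of the fallback path this diff adds, assuming a model that ignores the response format and wraps its JSON in a markdown fence (the function name, response text, and keys below are invented for illustration; the heuristic mirrors SimplifiedLLM.__extract_json_from_text above):

import json

def extract_json_from_text(text: str) -> str:
    # Slice from the first "{" to the last "}"; when no braces are
    # found, return the raw text and let json.loads raise as before.
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1:
        return text[start : end + 1]
    return text

response = 'Sure! Here it is:\n```json\n{"sentiment": "positive", "score": 0.9}\n```'
print(json.loads(extract_json_from_text(response)))
# -> {'sentiment': 'positive', 'score': 0.9}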
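
And a sketch of the prompt-side workaround in run(): since response_format stays None for these models, the expected shape is spelled out in the user message instead (the example schema and prompt text here are hypothetical):

import json

json_example = {"summary": "...", "confidence": 0.0}  # hypothetical schema
prompts = [dict(role="user", content="Review this patch.")]

# Mirrors the instruction appended in SimplifiedLLM.run()
prompts[-1][
    "content"
] += f"\nPlease format your response as a JSON object like this example:\n{json.dumps(json_example, indent=2)}"

print(prompts[-1]["content"])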