@@ -71,6 +71,7 @@
 from .pdl_parser import PDLParseError, parse_file
 from .pdl_scheduler import (
     CodeYieldResultMessage,
+    GeneratorWrapper,
     ModelCallMessage,
     ModelYieldResultMessage,
     YieldBackgroundMessage,
@@ -1058,7 +1059,9 @@ def get_transformed_inputs(kwargs):
 
         litellm.input_callback = [get_transformed_inputs]
         # append_log(state, "Model Input", messages_to_str(model_input))
-        msg = yield from generate_client_response(state, concrete_block, model_input)
+        msg, raw_result = yield from generate_client_response(
+            state, concrete_block, model_input
+        )
         if "input" in litellm_params:
             append_log(state, "Model Input", litellm_params["input"])
         else:
@@ -1069,6 +1072,8 @@ def get_transformed_inputs(kwargs):
         result = msg["content"]
         append_log(state, "Model Output", result)
         trace = block.model_copy(update={"result": result, "trace": concrete_block})
+        if block.modelResponse is not None:
+            scope = scope | {block.modelResponse: raw_result}
         return result, background, scope, trace
     except Exception as exc:
         message = f"Error during model call: {repr(exc)}"
@@ -1083,29 +1088,30 @@ def generate_client_response( # pylint: disable=too-many-arguments
     state: InterpreterState,
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
+    raw_result = None
     match state.batch:
         case 0:
-            model_output = yield from generate_client_response_streaming(
+            model_output, raw_result = yield from generate_client_response_streaming(
                 state, block, model_input
             )
         case 1:
-            model_output = yield from generate_client_response_single(
+            model_output, raw_result = yield from generate_client_response_single(
                 state, block, model_input
             )
         case _:
             model_output = yield from generate_client_response_batching(
                 state, block, model_input
             )
-    return model_output
+    return model_output, raw_result
 
 
 def generate_client_response_streaming(
     state: InterpreterState,
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
-    msg_stream: Generator[Message, Any, None]
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
+    msg_stream: Generator[Message, Any, Any]
     model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
@@ -1127,7 +1133,8 @@ def generate_client_response_streaming(
             assert False
     complete_msg: Optional[Message] = None
     role = None
-    for chunk in msg_stream:
+    wrapped_gen = GeneratorWrapper(msg_stream)
+    for chunk in wrapped_gen:
         if state.yield_result:
             yield ModelYieldResultMessage(chunk["content"])
         if state.yield_background:
@@ -1139,9 +1146,12 @@ def generate_client_response_streaming(
             chunk_role = chunk["role"]
             if chunk_role is None or chunk_role == role:
                 complete_msg["content"] += chunk["content"]
+    raw_result = None
+    if block.modelResponse is not None:
+        raw_result = wrapped_gen.value
     if complete_msg is None:
-        return Message(role=state.role, content="")
-    return complete_msg
+        return Message(role=state.role, content=""), raw_result
+    return complete_msg, raw_result
 
 
 def litellm_parameters_to_dict(
@@ -1159,12 +1169,12 @@ def generate_client_response_single(
     state: InterpreterState,
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
     msg: Message
     model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
-            msg = BamModel.generate_text(
+            msg, raw_result = BamModel.generate_text(
                 model_id=block.model,
                 prompt_id=block.prompt_id,
                 model_input=model_input_str,
@@ -1173,7 +1183,7 @@ def generate_client_response_single(
                 data=block.data,
             )
         case LitellmModelBlock():
-            msg = LitellmModel.generate_text(
+            msg, raw_result = LitellmModel.generate_text(
                 model_id=block.model,
                 messages=model_input,
                 parameters=litellm_parameters_to_dict(block.parameters),
@@ -1182,7 +1192,7 @@ def generate_client_response_single(
         yield YieldResultMessage(msg["content"])
     if state.yield_background:
         yield YieldBackgroundMessage([msg])
-    return msg
+    return msg, raw_result
 
 
 def generate_client_response_batching(  # pylint: disable=too-many-arguments
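
The streaming path above relies on GeneratorWrapper, imported from pdl_scheduler, exposing the wrapped generator's return value as `.value` once iteration finishes. As a rough mental model (an illustrative sketch, not the actual pdl_scheduler implementation), a wrapper with that contract can be built by catching StopIteration, which carries a generator's return value per PEP 380:

from typing import Any, Generator, Generic, Iterator, TypeVar

_YieldT = TypeVar("_YieldT")


class GeneratorWrapper(Generic[_YieldT]):
    """Iterates a generator and keeps its return value on `.value`."""

    def __init__(self, gen: Generator[_YieldT, Any, Any]) -> None:
        self.gen = gen
        self.value: Any = None  # set once the generator is exhausted

    def __iter__(self) -> Iterator[_YieldT]:
        # A plain `yield from` would forward the return value to *our*
        # caller instead; looping manually keeps it on the wrapper, so a
        # plain for-loop over the wrapper can still recover it afterwards.
        while True:
            try:
                yield next(self.gen)
            except StopIteration as stop:
                self.value = stop.value
                return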
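
Under that assumption, the capture pattern used in generate_client_response_streaming can be exercised in isolation; `stream` and its return payload below are made up for illustration:

def stream():  # stands in for msg_stream
    yield "hel"
    yield "lo"
    return {"usage": {"total_tokens": 2}}  # stands in for the raw response


wrapped = GeneratorWrapper(stream())
text = "".join(wrapped)  # consume the chunks, as the for-loop does
assert text == "hello"
assert wrapped.value == {"usage": {"total_tokens": 2}}  # the raw_result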