
Commit 5c3302b

Obtaining raw output from model calls (#165)
* model raw output
1 parent 5f74d3a commit 5c3302b

6 files changed: +75 −20 lines changed


pdl-live/src/pdl_ast.d.ts

Lines changed: 4 additions & 0 deletions
@@ -2529,6 +2529,7 @@ export type Trace4 =
  | ErrorBlock
  | EmptyBlock
  | null;
+export type Modelresponse = string | null;
export type Platform = "bam";
export type PromptId = string | null;
export type Parameters =
@@ -2738,6 +2739,7 @@ export type Trace5 =
  | ErrorBlock
  | EmptyBlock
  | null;
+export type Modelresponse1 = string | null;
export type Platform1 = "litellm";
export type Parameters1 =
  | LitellmParameters
@@ -3319,6 +3321,7 @@ export interface LitellmModelBlock {
  model: unknown;
  input?: Input1;
  trace?: Trace5;
+  modelResponse?: Modelresponse1;
  platform?: Platform1;
  parameters?: Parameters1;
}
@@ -3401,6 +3404,7 @@ export interface BamModelBlock {
  model: unknown;
  input?: Input;
  trace?: Trace4;
+  modelResponse?: Modelresponse;
  platform: Platform;
  prompt_id?: PromptId;
  parameters?: Parameters;

src/pdl/pdl-schema.json

Lines changed: 24 additions & 0 deletions
@@ -1311,6 +1311,18 @@
        "default": null,
        "title": "Trace"
      },
+      "modelResponse": {
+        "anyOf": [
+          {
+            "type": "string"
+          },
+          {
+            "type": "null"
+          }
+        ],
+        "default": null,
+        "title": "Modelresponse"
+      },
      "platform": {
        "const": "bam",
        "enum": [
@@ -9132,6 +9144,18 @@
        "default": null,
        "title": "Trace"
      },
+      "modelResponse": {
+        "anyOf": [
+          {
+            "type": "string"
+          },
+          {
+            "type": "null"
+          }
+        ],
+        "default": null,
+        "title": "Modelresponse"
+      },
      "platform": {
        "const": "litellm",
        "default": "litellm",

src/pdl/pdl_ast.py

Lines changed: 1 addition & 0 deletions
@@ -256,6 +256,7 @@ class ModelBlock(Block):
    model: str | ExpressionType
    input: Optional["BlocksType"] = None
    trace: Optional["BlockType"] = None
+    modelResponse: Optional[str] = None


class BamModelBlock(ModelBlock):
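
Judging from the interpreter change further down, modelResponse holds the name of a scope variable in which the raw model response will be stored, rather than the raw response itself. A minimal Python sketch of that effect, with hypothetical values:

# All values here are hypothetical, for illustration only.
block_modelResponse = "raw"            # the variable name a PDL block would declare
raw_result = {"id": "chatcmpl-123"}    # stand-in for a raw provider payload
scope: dict = {}
if block_modelResponse is not None:
    # mirrors `scope = scope | {block.modelResponse: raw_result}` in pdl_interpreter.py
    scope = scope | {block_modelResponse: raw_result}
assert scope == {"raw": {"id": "chatcmpl-123"}}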

src/pdl/pdl_dumper.py

Lines changed: 4 additions & 0 deletions
@@ -104,6 +104,8 @@ def block_to_dict(block: pdl_ast.BlockType, json_compatible: bool) -> DumpedBloc
            d["data"] = block.data
            if block.constraints is not None:
                d["constraints"] = block.constraints
+            if block.modelResponse is not None:
+                d["modelResponse"] = block.modelResponse
        case LitellmModelBlock():
            d["platform"] = block.platform
            d["model"] = block.model
@@ -116,6 +118,8 @@ def block_to_dict(block: pdl_ast.BlockType, json_compatible: bool) -> DumpedBloc
                )
            else:
                d["parameters"] = block.parameters
+            if block.modelResponse is not None:
+                d["modelResponse"] = block.modelResponse
        case CodeBlock():
            d["lang"] = block.lang
            d["code"] = blocks_to_dict(block.code, json_compatible)

src/pdl/pdl_interpreter.py

Lines changed: 24 additions & 14 deletions
@@ -71,6 +71,7 @@
from .pdl_parser import PDLParseError, parse_file
from .pdl_scheduler import (
    CodeYieldResultMessage,
+    GeneratorWrapper,
    ModelCallMessage,
    ModelYieldResultMessage,
    YieldBackgroundMessage,
@@ -1058,7 +1059,9 @@ def get_transformed_inputs(kwargs):

        litellm.input_callback = [get_transformed_inputs]
        # append_log(state, "Model Input", messages_to_str(model_input))
-        msg = yield from generate_client_response(state, concrete_block, model_input)
+        msg, raw_result = yield from generate_client_response(
+            state, concrete_block, model_input
+        )
        if "input" in litellm_params:
            append_log(state, "Model Input", litellm_params["input"])
        else:
@@ -1069,6 +1072,8 @@ def get_transformed_inputs(kwargs):
        result = msg["content"]
        append_log(state, "Model Output", result)
        trace = block.model_copy(update={"result": result, "trace": concrete_block})
+        if block.modelResponse is not None:
+            scope = scope | {block.modelResponse: raw_result}
        return result, background, scope, trace
    except Exception as exc:
        message = f"Error during model call: {repr(exc)}"
@@ -1083,29 +1088,30 @@ def generate_client_response( # pylint: disable=too-many-arguments
    state: InterpreterState,
    block: BamModelBlock | LitellmModelBlock,
    model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
+    raw_result = None
    match state.batch:
        case 0:
-            model_output = yield from generate_client_response_streaming(
+            model_output, raw_result = yield from generate_client_response_streaming(
                state, block, model_input
            )
        case 1:
-            model_output = yield from generate_client_response_single(
+            model_output, raw_result = yield from generate_client_response_single(
                state, block, model_input
            )
        case _:
            model_output = yield from generate_client_response_batching(
                state, block, model_input
            )
-    return model_output
+    return model_output, raw_result


def generate_client_response_streaming(
    state: InterpreterState,
    block: BamModelBlock | LitellmModelBlock,
    model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
-    msg_stream: Generator[Message, Any, None]
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
+    msg_stream: Generator[Message, Any, Any]
    model_input_str = messages_to_str(block.model, model_input)
    match block:
        case BamModelBlock():
@@ -1127,7 +1133,8 @@ def generate_client_response_streaming(
            assert False
    complete_msg: Optional[Message] = None
    role = None
-    for chunk in msg_stream:
+    wrapped_gen = GeneratorWrapper(msg_stream)
+    for chunk in wrapped_gen:
        if state.yield_result:
            yield ModelYieldResultMessage(chunk["content"])
        if state.yield_background:
@@ -1139,9 +1146,12 @@ def generate_client_response_streaming(
            chunk_role = chunk["role"]
            if chunk_role is None or chunk_role == role:
                complete_msg["content"] += chunk["content"]
+    raw_result = None
+    if block.modelResponse is not None:
+        raw_result = wrapped_gen.value
    if complete_msg is None:
-        return Message(role=state.role, content="")
-    return complete_msg
+        return Message(role=state.role, content=""), raw_result
+    return complete_msg, raw_result


def litellm_parameters_to_dict(
@@ -1159,12 +1169,12 @@ def generate_client_response_single(
    state: InterpreterState,
    block: BamModelBlock | LitellmModelBlock,
    model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
    msg: Message
    model_input_str = messages_to_str(block.model, model_input)
    match block:
        case BamModelBlock():
-            msg = BamModel.generate_text(
+            msg, raw_result = BamModel.generate_text(
                model_id=block.model,
                prompt_id=block.prompt_id,
                model_input=model_input_str,
@@ -1173,7 +1183,7 @@
                data=block.data,
            )
        case LitellmModelBlock():
-            msg = LitellmModel.generate_text(
+            msg, raw_result = LitellmModel.generate_text(
                model_id=block.model,
                messages=model_input,
                parameters=litellm_parameters_to_dict(block.parameters),
@@ -1182,7 +1192,7 @@
        yield YieldResultMessage(msg["content"])
    if state.yield_background:
        yield YieldBackgroundMessage([msg])
-    return msg
+    return msg, raw_result


def generate_client_response_batching( # pylint: disable=too-many-arguments
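
GeneratorWrapper is imported from pdl_scheduler, but its definition is not part of this diff. A minimal sketch of how such a wrapper is typically written, assuming its only job is to expose the wrapped generator's return value after iteration (the attribute read above as wrapped_gen.value):

from typing import Any, Generator, Generic, TypeVar

_YieldT = TypeVar("_YieldT")
_ReturnT = TypeVar("_ReturnT")

class GeneratorWrapper(Generic[_YieldT, _ReturnT]):
    """Iterate a generator and keep its return value in `value`."""

    def __init__(self, gen: Generator[_YieldT, Any, _ReturnT]):
        self.gen = gen
        self.value: _ReturnT | None = None  # filled in once the generator is exhausted

    def __iter__(self):
        # `yield from` re-yields every chunk and hands back the generator's
        # return value, which Python carries on StopIteration.
        self.value = yield from self.gen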

src/pdl/pdl_llms.py

Lines changed: 18 additions & 6 deletions
@@ -1,3 +1,4 @@
+import json
from typing import Any, Generator, Optional

import litellm
@@ -50,10 +51,11 @@ def generate_text( # pylint: disable=too-many-arguments,too-many-positional-arg
        parameters: Optional[dict | BamTextGenerationParameters],
        moderations: Optional[BamModerationParameters],
        data: Optional[BamPromptTemplateData],
-    ) -> Message:
+    ) -> tuple[Message, Any]:
        client = BamModel.get_model()
        params = set_default_model_params(parameters)
        text = ""
+        responses = []
        for response in client.text.generation.create(
            model_id=model_id,
            prompt_id=prompt_id,
@@ -63,10 +65,11 @@ def generate_text( # pylint: disable=too-many-arguments,too-many-positional-arg
            data=data,
        ):
            # XXX TODO: moderation
+            responses.append(response)
            for result in response.results:
                if result.generated_text:
                    text += result.generated_text
-        return {"role": None, "content": text}
+        return {"role": None, "content": text}, responses

    @staticmethod
    def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positional-arguments
@@ -76,9 +79,10 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio
        parameters: Optional[dict | BamTextGenerationParameters],
        moderations: Optional[BamModerationParameters],
        data: Optional[BamPromptTemplateData],
-    ) -> Generator[Message, Any, None]:
+    ) -> Generator[Message, Any, Any]:
        client = BamModel.get_model()
        params = set_default_model_params(parameters)
+        responses = []
        for response in client.text.generation.create_stream(
            model_id=model_id,
            prompt_id=prompt_id,
@@ -87,6 +91,7 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio
            moderations=moderations,
            data=data,
        ):
+            responses.append(json.loads(response.model_dump_json()))
            if response.results is None:
                # append_log(
                #     state,
@@ -97,6 +102,7 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio
            for result in response.results:
                if result.generated_text:
                    yield {"role": None, "content": result.generated_text}
+        return responses

    # @staticmethod
    # def generate_text_lazy( # pylint: disable=too-many-arguments
@@ -143,7 +149,7 @@ def generate_text(
        model_id: str,
        messages: list[Message],
        parameters: dict[str, Any],
-    ) -> Message:
+    ) -> tuple[Message, Any]:
        if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
            parameters = set_default_granite_model_parameters(model_id, parameters)
        if parameters.get("mock_response") is not None:
@@ -154,14 +160,17 @@ def generate_text(
        msg = response.choices[0].message  # pyright: ignore
        if msg.content is None:
            assert False, "TODO"  # XXX TODO XXX
-        return {"role": msg.role, "content": msg.content}
+        return {
+            "role": msg.role,
+            "content": msg.content,
+        }, response.json()  # pyright: ignore

    @staticmethod
    def generate_text_stream(
        model_id: str,
        messages: list[Message],
        parameters: dict[str, Any],
-    ) -> Generator[Message, Any, None]:
+    ) -> Generator[Message, Any, Any]:
        if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
            parameters = set_default_granite_model_parameters(model_id, parameters)
        response = completion(
@@ -170,8 +179,11 @@ def generate_text_stream(
            stream=True,
            **parameters,
        )
+        result = []
        for chunk in response:
+            result.append(chunk.json())  # pyright: ignore
            msg = chunk.choices[0].delta  # pyright: ignore
            if msg.content is None:
                break
            yield {"role": msg.role, "content": msg.content}
+        return result
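
The `return responses` and `return result` added at the end of the streaming functions rely on a standard Python mechanism: a generator's return value travels on StopIteration, which is what the GeneratorWrapper sketched above captures. A self-contained illustration with made-up chunk data (none of these names come from the commit):

from typing import Any, Generator

def fake_stream() -> Generator[dict, Any, list[dict]]:
    raw = []
    for text in ["Hel", "lo"]:
        chunk = {"content": text}
        raw.append(chunk)  # accumulate raw chunks, like responses.append(...)
        yield chunk        # stream parsed messages to the caller
    return raw             # delivered to the caller via StopIteration.value

def consume() -> list[dict]:
    gen = fake_stream()
    try:
        while True:
            next(gen)
    except StopIteration as stop:
        return stop.value  # the accumulated raw output

print(consume())  # -> [{'content': 'Hel'}, {'content': 'lo'}]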
