
Commit 4332002

model raw output

Signed-off-by: Mandana Vaziri <[email protected]>
1 parent 572373a, commit 4332002

6 files changed: +64, -8 lines


examples/hello/hello_model_raw.pdl

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+description: Hello world
+text:
+- text: "Hello\n"
+  contribute: [context]
+- model: ibm/granite-8b-code-instruct
+  platform: bam
+  def: output
+  modelResponse: raw_output
+  parameters:
+    decoding_method: greedy
+    stop_sequences: ["!"]
+    include_stop_sequence: true
+    return_options:
+      generated_tokens: True
+      token_logprobs: True
+  contribute: []
+- ${ raw_output }
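For orientation: the new modelResponse keyword names a scope variable that receives the raw, unparsed client response. Because the model block sets contribute: [], its generated text contributes nothing directly; the final ${ raw_output } expression renders the raw response instead. A minimal sketch of the binding semantics, with an illustrative payload (the variable names and payload shape here are assumptions, not PDL API; the real mechanism is in pdl_interpreter.py and pdl_llms.py below):

    # Illustrative only: stand-in names and payload, not the PDL API.
    scope: dict = {"context": "Hello\n"}
    raw_result = [{"results": [{"generated_text": "World!"}]}]  # assumed shape
    scope = scope | {"raw_output": raw_result}  # what modelResponse triggers
    print(scope["raw_output"])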

src/pdl/pdl-schema.json

Lines changed: 24 additions & 0 deletions

@@ -1311,6 +1311,18 @@
       "default": null,
       "title": "Trace"
     },
+    "modelResponse": {
+      "anyOf": [
+        {
+          "type": "string"
+        },
+        {
+          "type": "null"
+        }
+      ],
+      "default": null,
+      "title": "Modelresponse"
+    },
     "platform": {
       "const": "bam",
       "enum": [
@@ -9132,6 +9144,18 @@
       "default": null,
       "title": "Trace"
     },
+    "modelResponse": {
+      "anyOf": [
+        {
+          "type": "string"
+        },
+        {
+          "type": "null"
+        }
+      ],
+      "default": null,
+      "title": "Modelresponse"
+    },
     "platform": {
       "const": "litellm",
       "default": "litellm",
src/pdl/pdl_ast.py

Lines changed: 1 addition & 0 deletions

@@ -256,6 +256,7 @@ class ModelBlock(Block):
     model: str | ExpressionType
     input: Optional["BlocksType"] = None
     trace: Optional["BlockType"] = None
+    modelResponse: Optional[str] = None


 class BamModelBlock(ModelBlock):
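The field is optional and defaults to None, so existing programs are unaffected. A minimal, hypothetical sketch of how the Pydantic field behaves (heavily simplified from the real ModelBlock, which carries many more fields):

    from typing import Optional
    from pydantic import BaseModel

    class MiniModelBlock(BaseModel):  # stand-in for pdl_ast.ModelBlock
        model: str
        modelResponse: Optional[str] = None

    blk = MiniModelBlock(model="ibm/granite-8b-code-instruct",
                         modelResponse="raw_output")
    print(blk.modelResponse)  # -> raw_output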

src/pdl/pdl_dumper.py

Lines changed: 4 additions & 0 deletions

@@ -104,6 +104,8 @@ def block_to_dict(block: pdl_ast.BlockType, json_compatible: bool) -> DumpedBloc
             d["data"] = block.data
             if block.constraints is not None:
                 d["constraints"] = block.constraints
+            if block.modelResponse is not None:
+                d["modelResponse"] = block.modelResponse
         case LitellmModelBlock():
             d["platform"] = block.platform
             d["model"] = block.model
@@ -116,6 +118,8 @@ def block_to_dict(block: pdl_ast.BlockType, json_compatible: bool) -> DumpedBloc
                 )
             else:
                 d["parameters"] = block.parameters
+            if block.modelResponse is not None:
+                d["modelResponse"] = block.modelResponse
         case CodeBlock():
             d["lang"] = block.lang
             d["code"] = blocks_to_dict(block.code, json_compatible)
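The dumper guard is identical for the BAM and LiteLLM cases: a modelResponse name, when present, survives serialization into the dumped dict (and hence into traces). A self-contained sketch of the pattern, using a hypothetical stand-in block type rather than the real AST classes:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeModelBlock:  # stand-in for the real pdl_ast model blocks
        model: str
        modelResponse: Optional[str] = None

    def block_to_dict(block: FakeModelBlock) -> dict:
        d = {"model": block.model}
        if block.modelResponse is not None:  # same guard as pdl_dumper.py
            d["modelResponse"] = block.modelResponse
        return d

    print(block_to_dict(FakeModelBlock("ibm/granite-8b-code-instruct", "raw_output")))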

src/pdl/pdl_interpreter.py

Lines changed: 13 additions & 6 deletions

@@ -77,6 +77,7 @@
     YieldMessage,
     YieldResultMessage,
     schedule,
+    GeneratorWrapper,
 )
 from .pdl_schema_validator import type_check_args, type_check_spec
 from .pdl_utils import messages_concat, messages_to_str, stringify
@@ -1058,7 +1059,7 @@ def get_transformed_inputs(kwargs):

         litellm.input_callback = [get_transformed_inputs]
         # append_log(state, "Model Input", messages_to_str(model_input))
-        msg = yield from generate_client_response(state, concrete_block, model_input)
+        msg, raw_result = yield from generate_client_response(state, concrete_block, model_input)
         if "input" in litellm_params:
             append_log(state, "Model Input", litellm_params["input"])
         else:
@@ -1069,6 +1070,8 @@ def get_transformed_inputs(kwargs):
         result = msg["content"]
         append_log(state, "Model Output", result)
         trace = block.model_copy(update={"result": result, "trace": concrete_block})
+        if block.modelResponse is not None:
+            scope = scope | {block.modelResponse: raw_result}
         return result, background, scope, trace
     except Exception as exc:
         message = f"Error during model call: {repr(exc)}"
@@ -1086,7 +1089,7 @@ def generate_client_response(  # pylint: disable=too-many-arguments
 ) -> Generator[YieldMessage, Any, Message]:
     match state.batch:
         case 0:
-            model_output = yield from generate_client_response_streaming(
+            model_output, raw_result = yield from generate_client_response_streaming(
                 state, block, model_input
             )
         case 1:
@@ -1097,14 +1100,14 @@ def generate_client_response(  # pylint: disable=too-many-arguments
             model_output = yield from generate_client_response_batching(
                 state, block, model_input
             )
-    return model_output
+    return model_output, raw_result


 def generate_client_response_streaming(
     state: InterpreterState,
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
     msg_stream: Generator[Message, Any, None]
     model_input_str = messages_to_str(block.model, model_input)
     match block:
@@ -1127,7 +1130,8 @@ def generate_client_response_streaming(
             assert False
     complete_msg: Optional[Message] = None
     role = None
-    for chunk in msg_stream:
+    wrapped_gen = GeneratorWrapper(msg_stream)
+    for chunk in wrapped_gen:
         if state.yield_result:
             yield ModelYieldResultMessage(chunk["content"])
         if state.yield_background:
@@ -1139,9 +1143,12 @@ def generate_client_response_streaming(
             chunk_role = chunk["role"]
             if chunk_role is None or chunk_role == role:
                 complete_msg["content"] += chunk["content"]
+    raw_result = None
+    if block.modelResponse is not None:
+        raw_result = wrapped_gen.value
     if complete_msg is None:
         return Message(role=state.role, content="")
-    return complete_msg
+    return complete_msg, raw_result


 def litellm_parameters_to_dict(
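GeneratorWrapper, imported at the top of the file, is what lets the streaming loop recover the stream's return value: a bare for loop discards the value carried by StopIteration, while yield from propagates it. The class itself is not shown in this diff (it is presumably defined alongside schedule in the scheduler module); a minimal sketch of the standard PEP 380 idiom it would follow:

    from typing import Any, Generator, Iterator

    class GeneratorWrapper:
        """Iterate a generator while capturing its return value in .value.

        A sketch under assumptions, not the actual PDL class.
        """
        def __init__(self, gen: Generator[Any, Any, Any]) -> None:
            self.gen = gen
            self.value: Any = None

        def __iter__(self) -> Iterator[Any]:
            # yield from forwards every chunk and hands us the return value,
            # which a plain for loop over self.gen would silently drop.
            self.value = yield from self.gen

After the loop finishes, wrapped_gen.value holds whatever the underlying stream returned (the raw responses accumulated in pdl_llms.py), and it is looked at only when block.modelResponse asks for it.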

src/pdl/pdl_llms.py

Lines changed: 5 additions & 2 deletions

@@ -1,6 +1,7 @@
 from typing import Any, Generator, Optional

 import litellm
+import json
 from dotenv import load_dotenv
 from genai.client import Client as BamClient
 from genai.credentials import Credentials as BamCredentials
@@ -76,9 +77,10 @@ def generate_text_stream(  # pylint: disable=too-many-arguments,too-many-positio
         parameters: Optional[dict | BamTextGenerationParameters],
         moderations: Optional[BamModerationParameters],
         data: Optional[BamPromptTemplateData],
-    ) -> Generator[Message, Any, None]:
+    ) -> Generator[Message, Any, list[Any]]:
         client = BamModel.get_model()
         params = set_default_model_params(parameters)
+        responses = []
         for response in client.text.generation.create_stream(
             model_id=model_id,
             prompt_id=prompt_id,
@@ -87,6 +89,7 @@ def generate_text_stream(  # pylint: disable=too-many-arguments,too-many-positio
             moderations=moderations,
             data=data,
         ):
+            responses.append(json.loads(response.model_dump_json()))
             if response.results is None:
                 # append_log(
                 #     state,
@@ -97,7 +100,7 @@ def generate_text_stream(  # pylint: disable=too-many-arguments,too-many-positio
             for result in response.results:
                 if result.generated_text:
                     yield {"role": None, "content": result.generated_text}
-
+        return responses
     # @staticmethod
     # def generate_text_lazy(  # pylint: disable=too-many-arguments
     #     model_id: str,
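End to end, the raw responses ride the generator protocol: generate_text_stream now returns the accumulated list, generate_client_response_streaming captures it through GeneratorWrapper, and pdl_interpreter.py binds it into the scope. A tiny self-contained demo of that flow, with a hypothetical stand-in for the BAM stream:

    from typing import Any, Generator

    def fake_stream() -> Generator[str, Any, list[dict]]:
        responses = []
        for text in ["Hello", "!"]:
            responses.append({"generated_text": text})  # raw record per chunk
            yield text                                  # streamed content
        return responses  # travels back via StopIteration.value

    class GeneratorWrapper:  # same idiom as sketched above
        def __init__(self, gen):
            self.gen = gen
            self.value = None
        def __iter__(self):
            self.value = yield from self.gen

    wrapped = GeneratorWrapper(fake_stream())
    print("".join(wrapped))  # -> Hello!
    print(wrapped.value)     # -> [{'generated_text': 'Hello'}, {'generated_text': '!'}]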
