
Commit 801e2a6

bug fixes
Signed-off-by: Mandana Vaziri <[email protected]>
1 parent: a3eba5d

File tree

3 files changed: +20 −28 lines changed


examples/hello/hello_model_raw.pdl

Lines changed: 0 additions & 17 deletions
This file was deleted.

src/pdl/pdl_interpreter.py

Lines changed: 9 additions & 6 deletions
```diff
@@ -71,13 +71,13 @@
 from .pdl_parser import PDLParseError, parse_file
 from .pdl_scheduler import (
     CodeYieldResultMessage,
+    GeneratorWrapper,
     ModelCallMessage,
     ModelYieldResultMessage,
     YieldBackgroundMessage,
     YieldMessage,
     YieldResultMessage,
     schedule,
-    GeneratorWrapper,
 )
 from .pdl_schema_validator import type_check_args, type_check_spec
 from .pdl_utils import messages_concat, messages_to_str, stringify
@@ -1059,7 +1059,9 @@ def get_transformed_inputs(kwargs):
 
     litellm.input_callback = [get_transformed_inputs]
     # append_log(state, "Model Input", messages_to_str(model_input))
-    msg, raw_result = yield from generate_client_response(state, concrete_block, model_input)
+    msg, raw_result = yield from generate_client_response(
+        state, concrete_block, model_input
+    )
     if "input" in litellm_params:
         append_log(state, "Model Input", litellm_params["input"])
     else:
@@ -1087,13 +1089,14 @@ def generate_client_response(  # pylint: disable=too-many-arguments
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
 ) -> Generator[YieldMessage, Any, Message]:
+    raw_result = None
     match state.batch:
         case 0:
             model_output, raw_result = yield from generate_client_response_streaming(
                 state, block, model_input
             )
         case 1:
-            model_output = yield from generate_client_response_single(
+            model_output, raw_result = yield from generate_client_response_single(
                 state, block, model_input
             )
         case _:
@@ -1171,7 +1174,7 @@ def generate_client_response_single(
     model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
-            msg = BamModel.generate_text(
+            msg, raw_result = BamModel.generate_text(
                 model_id=block.model,
                 prompt_id=block.prompt_id,
                 model_input=model_input_str,
@@ -1180,7 +1183,7 @@
                 data=block.data,
             )
         case LitellmModelBlock():
-            msg = LitellmModel.generate_text(
+            msg, raw_result = LitellmModel.generate_text(
                 model_id=block.model,
                 messages=model_input,
                 parameters=litellm_parameters_to_dict(block.parameters),
@@ -1189,7 +1192,7 @@
     yield YieldResultMessage(msg["content"])
     if state.yield_background:
         yield YieldBackgroundMessage([msg])
-    return msg
+    return msg, raw_result
 
 
 def generate_client_response_batching( # pylint: disable=too-many-arguments
```
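The interpreter-side change relies on a Python generator delivering a final value through `yield from`: `generate_client_response` and `generate_client_response_single` now return a `(msg, raw_result)` tuple, which the caller unpacks directly from the `yield from` expression. A minimal, self-contained sketch of that mechanism (the names `produce` and `consume` are illustrative, not from the commit):

```python
from typing import Any, Generator

def produce() -> Generator[str, Any, tuple[str, list[int]]]:
    """Yields intermediate chunks, then returns a (text, raw) tuple."""
    yield "chunk-1"
    yield "chunk-2"
    # A `return` in a generator stashes the value on StopIteration;
    # it becomes the value of the caller's `yield from` expression.
    return "final text", [1, 2]

def consume() -> Generator[str, Any, None]:
    # Same shape as: msg, raw_result = yield from generate_client_response(...)
    msg, raw_result = yield from produce()
    print(msg, raw_result)

for chunk in consume():
    print(chunk)
```

This is why initializing `raw_result = None` before the `match` matters: every branch must leave `raw_result` bound before the function returns the tuple.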

src/pdl/pdl_llms.py

Lines changed: 11 additions & 5 deletions
```diff
@@ -1,7 +1,7 @@
+import json
 from typing import Any, Generator, Optional
 
 import litellm
-import json
 from dotenv import load_dotenv
 from genai.client import Client as BamClient
 from genai.credentials import Credentials as BamCredentials
@@ -51,10 +51,11 @@ def generate_text( # pylint: disable=too-many-arguments,too-many-positional-arg
         parameters: Optional[dict | BamTextGenerationParameters],
         moderations: Optional[BamModerationParameters],
         data: Optional[BamPromptTemplateData],
-    ) -> Message:
+    ) -> tuple[Message, list[Any]]:
        client = BamModel.get_model()
        params = set_default_model_params(parameters)
        text = ""
+       responses = []
        for response in client.text.generation.create(
            model_id=model_id,
            prompt_id=prompt_id,
@@ -64,10 +65,11 @@ def generate_text( # pylint: disable=too-many-arguments,too-many-positional-arg
            data=data,
        ):
            # XXX TODO: moderation
+           responses.append(response)
            for result in response.results:
                if result.generated_text:
                    text += result.generated_text
-       return {"role": None, "content": text}
+       return {"role": None, "content": text}, responses
 
     @staticmethod
     def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positional-arguments
@@ -101,6 +103,7 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio
                if result.generated_text:
                    yield {"role": None, "content": result.generated_text}
        return responses
+
    # @staticmethod
    # def generate_text_lazy( # pylint: disable=too-many-arguments
    #     model_id: str,
@@ -157,14 +160,14 @@ def generate_text(
        msg = response.choices[0].message # pyright: ignore
        if msg.content is None:
            assert False, "TODO" # XXX TODO XXX
-       return {"role": msg.role, "content": msg.content}
+       return {"role": msg.role, "content": msg.content}, response.json()
 
     @staticmethod
     def generate_text_stream(
         model_id: str,
         messages: list[Message],
         parameters: dict[str, Any],
-    ) -> Generator[Message, Any, None]:
+    ) -> Generator[Message, Any, list[Any]]:
        if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
            parameters = set_default_granite_model_parameters(model_id, parameters)
        response = completion(
@@ -173,8 +176,11 @@ def generate_text_stream(
            stream=True,
            **parameters,
        )
+       result = []
        for chunk in response:
+           result.append(chunk.json())
            msg = chunk.choices[0].delta # pyright: ignore
            if msg.content is None:
                break
            yield {"role": msg.role, "content": msg.content}
+       return result
```
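On the client side, both `generate_text` variants now hand back the raw provider responses alongside the assembled message, and the streaming variant returns the collected raw chunks as the generator's return value. Below is a rough sketch of that stream-and-collect pattern, using a toy chunk format rather than the real litellm/BAM response objects; the consumer is written in the spirit of the `GeneratorWrapper` import seen above, reading the return value off `StopIteration`:

```python
from typing import Any, Generator

Message = dict[str, Any]

def stream_with_raw(chunks: list[dict[str, str]]) -> Generator[Message, Any, list[Any]]:
    raw: list[Any] = []
    for chunk in chunks:
        raw.append(chunk)  # keep the raw payload, as generate_text_stream now does
        yield {"role": "assistant", "content": chunk["text"]}
    return raw  # surfaces to the consumer via StopIteration.value

gen = stream_with_raw([{"text": "Hel"}, {"text": "lo"}])
raw_chunks: list[Any] = []
try:
    while True:
        print(next(gen)["content"], end="")
except StopIteration as stop:
    raw_chunks = stop.value  # the list the generator returned
print("\nraw chunks:", raw_chunks)
```

Collecting the raw responses this way keeps the streaming path unchanged for callers that only consume the yielded messages, while making the unparsed provider output available for logging.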
