4 changes: 4 additions & 0 deletions pdl-live/src/pdl_ast.d.ts
@@ -2529,6 +2529,7 @@ export type Trace4 =
| ErrorBlock
| EmptyBlock
| null;
export type Modelresponse = string | null;
export type Platform = "bam";
export type PromptId = string | null;
export type Parameters =
@@ -2738,6 +2739,7 @@ export type Trace5 =
| ErrorBlock
| EmptyBlock
| null;
export type Modelresponse1 = string | null;
export type Platform1 = "litellm";
export type Parameters1 =
| LitellmParameters
@@ -3319,6 +3321,7 @@ export interface LitellmModelBlock {
model: unknown;
input?: Input1;
trace?: Trace5;
modelResponse?: Modelresponse1;
platform?: Platform1;
parameters?: Parameters1;
}
@@ -3401,6 +3404,7 @@ export interface BamModelBlock {
model: unknown;
input?: Input;
trace?: Trace4;
modelResponse?: Modelresponse;
platform: Platform;
prompt_id?: PromptId;
parameters?: Parameters;
24 changes: 24 additions & 0 deletions src/pdl/pdl-schema.json
@@ -1311,6 +1311,18 @@
"default": null,
"title": "Trace"
},
"modelResponse": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Modelresponse"
},
"platform": {
"const": "bam",
"enum": [
@@ -9132,6 +9144,18 @@
"default": null,
"title": "Trace"
},
"modelResponse": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Modelresponse"
},
"platform": {
"const": "litellm",
"default": "litellm",
1 change: 1 addition & 0 deletions src/pdl/pdl_ast.py
@@ -256,6 +256,7 @@ class ModelBlock(Block):
model: str | ExpressionType
input: Optional["BlocksType"] = None
trace: Optional["BlockType"] = None
modelResponse: Optional[str] = None


class BamModelBlock(ModelBlock):
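Note: the new modelResponse field on ModelBlock names a scope variable that will receive the raw provider response (the interpreter change below merges it into scope). A minimal sketch of the intended usage, in Python dict form mirroring the updated schema; the variable name "raw" and the model id are illustrative assumptions, not taken from this PR:

# Hypothetical model block using the new key; "raw" and the model id are
# illustrative only.
block = {
    "platform": "litellm",
    "model": "watsonx/ibm/granite-13b-chat-v2",
    "modelResponse": "raw",  # ask the interpreter to bind the raw response to "raw"
}
# After the model call, later blocks can read the unprocessed response through
# the scope variable named here.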
4 changes: 4 additions & 0 deletions src/pdl/pdl_dumper.py
@@ -104,6 +104,8 @@ def block_to_dict(block: pdl_ast.BlockType, json_compatible: bool) -> DumpedBloc
d["data"] = block.data
if block.constraints is not None:
d["constraints"] = block.constraints
if block.modelResponse is not None:
d["modelResponse"] = block.modelResponse
case LitellmModelBlock():
d["platform"] = block.platform
d["model"] = block.model
@@ -116,6 +118,8 @@
)
else:
d["parameters"] = block.parameters
if block.modelResponse is not None:
d["modelResponse"] = block.modelResponse
case CodeBlock():
d["lang"] = block.lang
d["code"] = blocks_to_dict(block.code, json_compatible)
38 changes: 24 additions & 14 deletions src/pdl/pdl_interpreter.py
@@ -71,6 +71,7 @@
from .pdl_parser import PDLParseError, parse_file
from .pdl_scheduler import (
CodeYieldResultMessage,
GeneratorWrapper,
ModelCallMessage,
ModelYieldResultMessage,
YieldBackgroundMessage,
@@ -1058,7 +1059,9 @@ def get_transformed_inputs(kwargs):

litellm.input_callback = [get_transformed_inputs]
# append_log(state, "Model Input", messages_to_str(model_input))
msg = yield from generate_client_response(state, concrete_block, model_input)
msg, raw_result = yield from generate_client_response(
state, concrete_block, model_input
)
if "input" in litellm_params:
append_log(state, "Model Input", litellm_params["input"])
else:
@@ -1069,6 +1072,8 @@
result = msg["content"]
append_log(state, "Model Output", result)
trace = block.model_copy(update={"result": result, "trace": concrete_block})
if block.modelResponse is not None:
scope = scope | {block.modelResponse: raw_result}
return result, background, scope, trace
except Exception as exc:
message = f"Error during model call: {repr(exc)}"
@@ -1083,29 +1088,30 @@ def generate_client_response( # pylint: disable=too-many-arguments
state: InterpreterState,
block: BamModelBlock | LitellmModelBlock,
model_input: Messages,
) -> Generator[YieldMessage, Any, Message]:
) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
raw_result = None
match state.batch:
case 0:
model_output = yield from generate_client_response_streaming(
model_output, raw_result = yield from generate_client_response_streaming(
state, block, model_input
)
case 1:
model_output = yield from generate_client_response_single(
model_output, raw_result = yield from generate_client_response_single(
state, block, model_input
)
case _:
model_output = yield from generate_client_response_batching(
state, block, model_input
)
return model_output
return model_output, raw_result


def generate_client_response_streaming(
state: InterpreterState,
block: BamModelBlock | LitellmModelBlock,
model_input: Messages,
) -> Generator[YieldMessage, Any, Message]:
msg_stream: Generator[Message, Any, None]
) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
msg_stream: Generator[Message, Any, Any]
model_input_str = messages_to_str(block.model, model_input)
match block:
case BamModelBlock():
@@ -1127,7 +1133,8 @@ def generate_client_response_streaming(
assert False
complete_msg: Optional[Message] = None
role = None
for chunk in msg_stream:
wrapped_gen = GeneratorWrapper(msg_stream)
for chunk in wrapped_gen:
if state.yield_result:
yield ModelYieldResultMessage(chunk["content"])
if state.yield_background:
@@ -1139,9 +1146,12 @@ def generate_client_response_streaming(
chunk_role = chunk["role"]
if chunk_role is None or chunk_role == role:
complete_msg["content"] += chunk["content"]
raw_result = None
if block.modelResponse is not None:
raw_result = wrapped_gen.value
if complete_msg is None:
return Message(role=state.role, content="")
return complete_msg
return Message(role=state.role, content=""), raw_result
return complete_msg, raw_result


def litellm_parameters_to_dict(
@@ -1159,12 +1169,12 @@ def generate_client_response_single(
state: InterpreterState,
block: BamModelBlock | LitellmModelBlock,
model_input: Messages,
) -> Generator[YieldMessage, Any, Message]:
) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
msg: Message
model_input_str = messages_to_str(block.model, model_input)
match block:
case BamModelBlock():
msg = BamModel.generate_text(
msg, raw_result = BamModel.generate_text(
model_id=block.model,
prompt_id=block.prompt_id,
model_input=model_input_str,
@@ -1173,7 +1183,7 @@
data=block.data,
)
case LitellmModelBlock():
msg = LitellmModel.generate_text(
msg, raw_result = LitellmModel.generate_text(
model_id=block.model,
messages=model_input,
parameters=litellm_parameters_to_dict(block.parameters),
@@ -1182,7 +1192,7 @@
yield YieldResultMessage(msg["content"])
if state.yield_background:
yield YieldBackgroundMessage([msg])
return msg
return msg, raw_result


def generate_client_response_batching( # pylint: disable=too-many-arguments
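Note: the streaming path now needs the value that the message generator returns (the accumulated raw responses) in addition to the chunks it yields, which is what GeneratorWrapper, imported from pdl_scheduler, provides through its value attribute. The sketch below shows the general pattern assumed here: capturing a generator's return value while iterating normally. It is an illustration, not the actual pdl_scheduler implementation.

from typing import Any, Generator, Generic, Iterator, TypeVar

T = TypeVar("T")

class WrapperSketch(Generic[T]):
    # Iterate a generator and keep the value it returns (illustrative only).
    def __init__(self, gen: Generator[T, Any, Any]):
        self.gen = gen
        self.value: Any = None  # set once the wrapped generator is exhausted

    def __iter__(self) -> Iterator[T]:
        # "yield from" evaluates to the wrapped generator's return value.
        self.value = yield from self.gen

def stream() -> Generator[dict, Any, list]:
    yield {"role": None, "content": "Hello"}
    return ["raw chunk 1", "raw chunk 2"]  # raw responses are returned, not yielded

wrapped = WrapperSketch(stream())
for chunk in wrapped:
    print(chunk["content"])
print(wrapped.value)  # -> ['raw chunk 1', 'raw chunk 2']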
24 changes: 18 additions & 6 deletions src/pdl/pdl_llms.py
@@ -1,3 +1,4 @@
import json
from typing import Any, Generator, Optional

import litellm
@@ -50,10 +51,11 @@ def generate_text( # pylint: disable=too-many-arguments,too-many-positional-arg
parameters: Optional[dict | BamTextGenerationParameters],
moderations: Optional[BamModerationParameters],
data: Optional[BamPromptTemplateData],
) -> Message:
) -> tuple[Message, Any]:
client = BamModel.get_model()
params = set_default_model_params(parameters)
text = ""
responses = []
for response in client.text.generation.create(
model_id=model_id,
prompt_id=prompt_id,
@@ -63,10 +65,11 @@ def generate_text( # pylint: disable=too-many-arguments,too-many-positional-arg
data=data,
):
# XXX TODO: moderation
responses.append(response)
for result in response.results:
if result.generated_text:
text += result.generated_text
return {"role": None, "content": text}
return {"role": None, "content": text}, responses

@staticmethod
def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positional-arguments
@@ -76,9 +79,10 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio
parameters: Optional[dict | BamTextGenerationParameters],
moderations: Optional[BamModerationParameters],
data: Optional[BamPromptTemplateData],
) -> Generator[Message, Any, None]:
) -> Generator[Message, Any, Any]:
client = BamModel.get_model()
params = set_default_model_params(parameters)
responses = []
for response in client.text.generation.create_stream(
model_id=model_id,
prompt_id=prompt_id,
@@ -87,6 +91,7 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio
moderations=moderations,
data=data,
):
responses.append(json.loads(response.model_dump_json()))
Collaborator: This probably could be

responses.append(response.model_dump())

Member Author: I tried but it didn't work.

if response.results is None:
# append_log(
# state,
Expand All @@ -97,6 +102,7 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio
for result in response.results:
if result.generated_text:
yield {"role": None, "content": result.generated_text}
return responses

# @staticmethod
# def generate_text_lazy( # pylint: disable=too-many-arguments
@@ -143,7 +149,7 @@ def generate_text(
model_id: str,
messages: list[Message],
parameters: dict[str, Any],
) -> Message:
) -> tuple[Message, Any]:
if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
parameters = set_default_granite_model_parameters(model_id, parameters)
if parameters.get("mock_response") is not None:
@@ -154,14 +160,17 @@
msg = response.choices[0].message # pyright: ignore
if msg.content is None:
assert False, "TODO" # XXX TODO XXX
return {"role": msg.role, "content": msg.content}
return {
"role": msg.role,
"content": msg.content,
}, response.json() # pyright: ignore

@staticmethod
def generate_text_stream(
model_id: str,
messages: list[Message],
parameters: dict[str, Any],
) -> Generator[Message, Any, None]:
) -> Generator[Message, Any, Any]:
if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
parameters = set_default_granite_model_parameters(model_id, parameters)
response = completion(
@@ -170,8 +179,11 @@ def generate_text_stream(
stream=True,
**parameters,
)
result = []
for chunk in response:
result.append(chunk.json()) # pyright: ignore
msg = chunk.choices[0].delta # pyright: ignore
if msg.content is None:
break
yield {"role": msg.role, "content": msg.content}
return result
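
Note: with these changes, both generate_text variants return a (message, raw) pair, and both generate_text_stream variants return the accumulated raw responses as the generator's return value. A hedged usage sketch follows; the model id, message, and empty parameters are illustrative assumptions, not taken from this PR.

# Single-shot call: the raw provider response comes back alongside the message.
msg, raw = LitellmModel.generate_text(
    model_id="gpt-4o-mini",  # assumed model id, for illustration only
    messages=[{"role": "user", "content": "Hi"}],
    parameters={},
)

# Streaming call: the raw chunks are the generator's return value, visible as
# StopIteration.value when the stream is driven manually (the interpreter uses
# GeneratorWrapper for the same purpose).
stream = LitellmModel.generate_text_stream(
    model_id="gpt-4o-mini",  # assumed model id
    messages=[{"role": "user", "content": "Hi"}],
    parameters={},
)
try:
    while True:
        print(next(stream)["content"], end="")
except StopIteration as stop:
    raw_chunks = stop.value  # the list built by result.append(chunk.json())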