diff --git a/pdl-live/src/pdl_ast.d.ts b/pdl-live/src/pdl_ast.d.ts
index 6724d767d..a7080591b 100644
--- a/pdl-live/src/pdl_ast.d.ts
+++ b/pdl-live/src/pdl_ast.d.ts
@@ -2529,6 +2529,7 @@ export type Trace4 =
   | ErrorBlock
   | EmptyBlock
   | null;
+export type Modelresponse = string | null;
 export type Platform = "bam";
 export type PromptId = string | null;
 export type Parameters =
@@ -2738,6 +2739,7 @@ export type Trace5 =
   | ErrorBlock
   | EmptyBlock
   | null;
+export type Modelresponse1 = string | null;
 export type Platform1 = "litellm";
 export type Parameters1 =
   | LitellmParameters
@@ -3319,6 +3321,7 @@ export interface LitellmModelBlock {
   model: unknown;
   input?: Input1;
   trace?: Trace5;
+  modelResponse?: Modelresponse1;
   platform?: Platform1;
   parameters?: Parameters1;
 }
@@ -3401,6 +3404,7 @@ export interface BamModelBlock {
   model: unknown;
   input?: Input;
   trace?: Trace4;
+  modelResponse?: Modelresponse;
   platform: Platform;
   prompt_id?: PromptId;
   parameters?: Parameters;
diff --git a/src/pdl/pdl-schema.json b/src/pdl/pdl-schema.json
index 9b266132a..35447a176 100644
--- a/src/pdl/pdl-schema.json
+++ b/src/pdl/pdl-schema.json
@@ -1311,6 +1311,18 @@
           "default": null,
           "title": "Trace"
         },
+        "modelResponse": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "title": "Modelresponse"
+        },
         "platform": {
           "const": "bam",
           "enum": [
@@ -9132,6 +9144,18 @@
           "default": null,
           "title": "Trace"
         },
+        "modelResponse": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "title": "Modelresponse"
+        },
         "platform": {
           "const": "litellm",
           "default": "litellm",
diff --git a/src/pdl/pdl_ast.py b/src/pdl/pdl_ast.py
index 1a4557516..dde2534ed 100644
--- a/src/pdl/pdl_ast.py
+++ b/src/pdl/pdl_ast.py
@@ -256,6 +256,7 @@ class ModelBlock(Block):
     model: str | ExpressionType
     input: Optional["BlocksType"] = None
     trace: Optional["BlockType"] = None
+    modelResponse: Optional[str] = None
 
 
 class BamModelBlock(ModelBlock):
diff --git a/src/pdl/pdl_dumper.py b/src/pdl/pdl_dumper.py
index 8c6dd261f..3657cd9c8 100644
--- a/src/pdl/pdl_dumper.py
+++ b/src/pdl/pdl_dumper.py
@@ -104,6 +104,8 @@ def block_to_dict(block: pdl_ast.BlockType, json_compatible: bool) -> DumpedBloc
             d["data"] = block.data
             if block.constraints is not None:
                 d["constraints"] = block.constraints
+            if block.modelResponse is not None:
+                d["modelResponse"] = block.modelResponse
         case LitellmModelBlock():
             d["platform"] = block.platform
             d["model"] = block.model
@@ -116,6 +118,8 @@ def block_to_dict(block: pdl_ast.BlockType, json_compatible: bool) -> DumpedBloc
                 )
             else:
                 d["parameters"] = block.parameters
+            if block.modelResponse is not None:
+                d["modelResponse"] = block.modelResponse
         case CodeBlock():
             d["lang"] = block.lang
             d["code"] = blocks_to_dict(block.code, json_compatible)
diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py
index f9e919ac0..31b61be7f 100644
--- a/src/pdl/pdl_interpreter.py
+++ b/src/pdl/pdl_interpreter.py
@@ -71,6 +71,7 @@
 from .pdl_parser import PDLParseError, parse_file
 from .pdl_scheduler import (
     CodeYieldResultMessage,
+    GeneratorWrapper,
     ModelCallMessage,
     ModelYieldResultMessage,
     YieldBackgroundMessage,
@@ -1058,7 +1059,9 @@ def get_transformed_inputs(kwargs):
 
         litellm.input_callback = [get_transformed_inputs]
         # append_log(state, "Model Input", messages_to_str(model_input))
-        msg = yield from generate_client_response(state, concrete_block, model_input)
+        msg, raw_result = yield from generate_client_response(
+            state, concrete_block, model_input
+        )
         if "input" in litellm_params:
             append_log(state, "Model Input", litellm_params["input"])
         else:
@@ -1069,6 +1072,8 @@ def get_transformed_inputs(kwargs):
         result = msg["content"]
         append_log(state, "Model Output", result)
         trace = block.model_copy(update={"result": result, "trace": concrete_block})
+        if block.modelResponse is not None:
+            scope = scope | {block.modelResponse: raw_result}
         return result, background, scope, trace
     except Exception as exc:
         message = f"Error during model call: {repr(exc)}"
@@ -1083,29 +1088,30 @@ def generate_client_response(  # pylint: disable=too-many-arguments
     state: InterpreterState,
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
+    raw_result = None
     match state.batch:
         case 0:
-            model_output = yield from generate_client_response_streaming(
+            model_output, raw_result = yield from generate_client_response_streaming(
                 state, block, model_input
             )
         case 1:
-            model_output = yield from generate_client_response_single(
+            model_output, raw_result = yield from generate_client_response_single(
                 state, block, model_input
             )
         case _:
             model_output = yield from generate_client_response_batching(
                 state, block, model_input
             )
-    return model_output
+    return model_output, raw_result
 
 
 def generate_client_response_streaming(
     state: InterpreterState,
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
-    msg_stream: Generator[Message, Any, None]
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
+    msg_stream: Generator[Message, Any, Any]
     model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
@@ -1127,7 +1133,8 @@ def generate_client_response_streaming(
             assert False
     complete_msg: Optional[Message] = None
     role = None
-    for chunk in msg_stream:
+    wrapped_gen = GeneratorWrapper(msg_stream)
+    for chunk in wrapped_gen:
         if state.yield_result:
             yield ModelYieldResultMessage(chunk["content"])
         if state.yield_background:
@@ -1139,9 +1146,12 @@ def generate_client_response_streaming(
             chunk_role = chunk["role"]
             if chunk_role is None or chunk_role == role:
                 complete_msg["content"] += chunk["content"]
+    raw_result = None
+    if block.modelResponse is not None:
+        raw_result = wrapped_gen.value
     if complete_msg is None:
-        return Message(role=state.role, content="")
-    return complete_msg
+        return Message(role=state.role, content=""), raw_result
+    return complete_msg, raw_result
 
 
 def litellm_parameters_to_dict(
@@ -1159,12 +1169,12 @@ def generate_client_response_single(
     state: InterpreterState,
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
     msg: Message
     model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
-            msg = BamModel.generate_text(
+            msg, raw_result = BamModel.generate_text(
                 model_id=block.model,
                 prompt_id=block.prompt_id,
                 model_input=model_input_str,
@@ -1173,7 +1183,7 @@ def generate_client_response_single(
                 data=block.data,
             )
         case LitellmModelBlock():
-            msg = LitellmModel.generate_text(
+            msg, raw_result = LitellmModel.generate_text(
                 model_id=block.model,
                 messages=model_input,
                 parameters=litellm_parameters_to_dict(block.parameters),
@@ -1182,7 +1192,7 @@ def generate_client_response_single(
         yield YieldResultMessage(msg["content"])
     if state.yield_background:
         yield YieldBackgroundMessage([msg])
-    return msg
+    return msg, raw_result
 
 
 def generate_client_response_batching(  # pylint: disable=too-many-arguments
diff --git a/src/pdl/pdl_llms.py b/src/pdl/pdl_llms.py
index 6c928891f..c2780207c 100644
--- a/src/pdl/pdl_llms.py
+++ b/src/pdl/pdl_llms.py
@@ -1,3 +1,4 @@
+import json
 from typing import Any, Generator, Optional
 
 import litellm
@@ -50,10 +51,11 @@ def generate_text(  # pylint: disable=too-many-arguments,too-many-positional-arg
         parameters: Optional[dict | BamTextGenerationParameters],
         moderations: Optional[BamModerationParameters],
         data: Optional[BamPromptTemplateData],
-    ) -> Message:
+    ) -> tuple[Message, Any]:
         client = BamModel.get_model()
         params = set_default_model_params(parameters)
         text = ""
+        responses = []
         for response in client.text.generation.create(
             model_id=model_id,
             prompt_id=prompt_id,
@@ -63,10 +65,11 @@ def generate_text(  # pylint: disable=too-many-arguments,too-many-positional-arg
             data=data,
         ):
             # XXX TODO: moderation
+            responses.append(response)
             for result in response.results:
                 if result.generated_text:
                     text += result.generated_text
-        return {"role": None, "content": text}
+        return {"role": None, "content": text}, responses
 
     @staticmethod
     def generate_text_stream(  # pylint: disable=too-many-arguments,too-many-positional-arguments
@@ -76,9 +79,10 @@ def generate_text_stream(  # pylint: disable=too-many-arguments,too-many-positio
         parameters: Optional[dict | BamTextGenerationParameters],
         moderations: Optional[BamModerationParameters],
         data: Optional[BamPromptTemplateData],
-    ) -> Generator[Message, Any, None]:
+    ) -> Generator[Message, Any, Any]:
         client = BamModel.get_model()
         params = set_default_model_params(parameters)
+        responses = []
         for response in client.text.generation.create_stream(
             model_id=model_id,
             prompt_id=prompt_id,
@@ -87,6 +91,7 @@ def generate_text_stream(  # pylint: disable=too-many-arguments,too-many-positio
             moderations=moderations,
             data=data,
         ):
+            responses.append(json.loads(response.model_dump_json()))
             if response.results is None:
                 # append_log(
                 #     state,
@@ -97,6 +102,7 @@ def generate_text_stream(  # pylint: disable=too-many-arguments,too-many-positio
             for result in response.results:
                 if result.generated_text:
                     yield {"role": None, "content": result.generated_text}
+        return responses
 
     # @staticmethod
     # def generate_text_lazy(  # pylint: disable=too-many-arguments
@@ -143,7 +149,7 @@ def generate_text(
         model_id: str,
         messages: list[Message],
         parameters: dict[str, Any],
-    ) -> Message:
+    ) -> tuple[Message, Any]:
         if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
             parameters = set_default_granite_model_parameters(model_id, parameters)
         if parameters.get("mock_response") is not None:
@@ -154,14 +160,17 @@ def generate_text(
         msg = response.choices[0].message  # pyright: ignore
         if msg.content is None:
             assert False, "TODO"  # XXX TODO XXX
-        return {"role": msg.role, "content": msg.content}
+        return {
+            "role": msg.role,
+            "content": msg.content,
+        }, response.json()  # pyright: ignore
 
     @staticmethod
     def generate_text_stream(
         model_id: str,
         messages: list[Message],
         parameters: dict[str, Any],
-    ) -> Generator[Message, Any, None]:
+    ) -> Generator[Message, Any, Any]:
         if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
             parameters = set_default_granite_model_parameters(model_id, parameters)
         response = completion(
@@ -170,8 +179,11 @@ def generate_text_stream(
             stream=True,
             **parameters,
         )
+        result = []
         for chunk in response:
+            result.append(chunk.json())  # pyright: ignore
             msg = chunk.choices[0].delta  # pyright: ignore
             if msg.content is None:
                 break
             yield {"role": msg.role, "content": msg.content}
+        return result