From 4332002ccac30c056b3d19f6a14594013302afa2 Mon Sep 17 00:00:00 2001 From: Mandana Vaziri Date: Tue, 29 Oct 2024 08:41:15 -0400 Subject: [PATCH 1/5] model raw output Signed-off-by: Mandana Vaziri --- examples/hello/hello_model_raw.pdl | 17 +++++++++++++++++ src/pdl/pdl-schema.json | 24 ++++++++++++++++++++++++ src/pdl/pdl_ast.py | 1 + src/pdl/pdl_dumper.py | 4 ++++ src/pdl/pdl_interpreter.py | 19 +++++++++++++------ src/pdl/pdl_llms.py | 7 +++++-- 6 files changed, 64 insertions(+), 8 deletions(-) create mode 100644 examples/hello/hello_model_raw.pdl diff --git a/examples/hello/hello_model_raw.pdl b/examples/hello/hello_model_raw.pdl new file mode 100644 index 000000000..911d3f69f --- /dev/null +++ b/examples/hello/hello_model_raw.pdl @@ -0,0 +1,17 @@ +description: Hello world +text: +- text: "Hello\n" + contribute: [context] +- model: ibm/granite-8b-code-instruct + platform: bam + def: output + modelResponse: raw_output + parameters: + decoding_method: greedy + stop_sequences: ["!"] + include_stop_sequence: true + return_options: + generated_tokens: True + token_logprobs: True + contribute: [] +- ${ raw_output } \ No newline at end of file diff --git a/src/pdl/pdl-schema.json b/src/pdl/pdl-schema.json index 9b266132a..35447a176 100644 --- a/src/pdl/pdl-schema.json +++ b/src/pdl/pdl-schema.json @@ -1311,6 +1311,18 @@ "default": null, "title": "Trace" }, + "modelResponse": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Modelresponse" + }, "platform": { "const": "bam", "enum": [ @@ -9132,6 +9144,18 @@ "default": null, "title": "Trace" }, + "modelResponse": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Modelresponse" + }, "platform": { "const": "litellm", "default": "litellm", diff --git a/src/pdl/pdl_ast.py b/src/pdl/pdl_ast.py index 1a4557516..dde2534ed 100644 --- a/src/pdl/pdl_ast.py +++ b/src/pdl/pdl_ast.py @@ -256,6 +256,7 @@ class ModelBlock(Block): model: str | ExpressionType input: Optional["BlocksType"] = None trace: Optional["BlockType"] = None + modelResponse: Optional[str] = None class BamModelBlock(ModelBlock): diff --git a/src/pdl/pdl_dumper.py b/src/pdl/pdl_dumper.py index 8c6dd261f..3657cd9c8 100644 --- a/src/pdl/pdl_dumper.py +++ b/src/pdl/pdl_dumper.py @@ -104,6 +104,8 @@ def block_to_dict(block: pdl_ast.BlockType, json_compatible: bool) -> DumpedBloc d["data"] = block.data if block.constraints is not None: d["constraints"] = block.constraints + if block.modelResponse is not None: + d["modelResponse"] = block.modelResponse case LitellmModelBlock(): d["platform"] = block.platform d["model"] = block.model @@ -116,6 +118,8 @@ def block_to_dict(block: pdl_ast.BlockType, json_compatible: bool) -> DumpedBloc ) else: d["parameters"] = block.parameters + if block.modelResponse is not None: + d["modelResponse"] = block.modelResponse case CodeBlock(): d["lang"] = block.lang d["code"] = blocks_to_dict(block.code, json_compatible) diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index f9e919ac0..521665a38 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -77,6 +77,7 @@ YieldMessage, YieldResultMessage, schedule, + GeneratorWrapper, ) from .pdl_schema_validator import type_check_args, type_check_spec from .pdl_utils import messages_concat, messages_to_str, stringify @@ -1058,7 +1059,7 @@ def get_transformed_inputs(kwargs): litellm.input_callback = [get_transformed_inputs] # append_log(state, "Model Input", 
messages_to_str(model_input)) - msg = yield from generate_client_response(state, concrete_block, model_input) + msg, raw_result = yield from generate_client_response(state, concrete_block, model_input) if "input" in litellm_params: append_log(state, "Model Input", litellm_params["input"]) else: @@ -1069,6 +1070,8 @@ def get_transformed_inputs(kwargs): result = msg["content"] append_log(state, "Model Output", result) trace = block.model_copy(update={"result": result, "trace": concrete_block}) + if block.modelResponse is not None: + scope = scope | {block.modelResponse: raw_result} return result, background, scope, trace except Exception as exc: message = f"Error during model call: {repr(exc)}" @@ -1086,7 +1089,7 @@ def generate_client_response( # pylint: disable=too-many-arguments ) -> Generator[YieldMessage, Any, Message]: match state.batch: case 0: - model_output = yield from generate_client_response_streaming( + model_output, raw_result = yield from generate_client_response_streaming( state, block, model_input ) case 1: @@ -1097,14 +1100,14 @@ def generate_client_response( # pylint: disable=too-many-arguments model_output = yield from generate_client_response_batching( state, block, model_input ) - return model_output + return model_output, raw_result def generate_client_response_streaming( state: InterpreterState, block: BamModelBlock | LitellmModelBlock, model_input: Messages, -) -> Generator[YieldMessage, Any, Message]: +) -> Generator[YieldMessage, Any, tuple[Message, Any]]: msg_stream: Generator[Message, Any, None] model_input_str = messages_to_str(block.model, model_input) match block: @@ -1127,7 +1130,8 @@ def generate_client_response_streaming( assert False complete_msg: Optional[Message] = None role = None - for chunk in msg_stream: + wrapped_gen = GeneratorWrapper(msg_stream) + for chunk in wrapped_gen: if state.yield_result: yield ModelYieldResultMessage(chunk["content"]) if state.yield_background: @@ -1139,9 +1143,12 @@ def generate_client_response_streaming( chunk_role = chunk["role"] if chunk_role is None or chunk_role == role: complete_msg["content"] += chunk["content"] + raw_result = None + if block.modelResponse is not None: + raw_result = wrapped_gen.value if complete_msg is None: return Message(role=state.role, content="") - return complete_msg + return complete_msg, raw_result def litellm_parameters_to_dict( diff --git a/src/pdl/pdl_llms.py b/src/pdl/pdl_llms.py index 6c928891f..2992af6de 100644 --- a/src/pdl/pdl_llms.py +++ b/src/pdl/pdl_llms.py @@ -1,6 +1,7 @@ from typing import Any, Generator, Optional import litellm +import json from dotenv import load_dotenv from genai.client import Client as BamClient from genai.credentials import Credentials as BamCredentials @@ -76,9 +77,10 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio parameters: Optional[dict | BamTextGenerationParameters], moderations: Optional[BamModerationParameters], data: Optional[BamPromptTemplateData], - ) -> Generator[Message, Any, None]: + ) -> Generator[Message, Any, list[Any]]: client = BamModel.get_model() params = set_default_model_params(parameters) + responses = [] for response in client.text.generation.create_stream( model_id=model_id, prompt_id=prompt_id, @@ -87,6 +89,7 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio moderations=moderations, data=data, ): + responses.append(json.loads(response.model_dump_json())) if response.results is None: # append_log( # state, @@ -97,7 +100,7 @@ def generate_text_stream( # pylint: 
disable=too-many-arguments,too-many-positio for result in response.results: if result.generated_text: yield {"role": None, "content": result.generated_text} - + return responses # @staticmethod # def generate_text_lazy( # pylint: disable=too-many-arguments # model_id: str, From 801e2a6bbdcaf7ff42c29ea9be583a8068d2000c Mon Sep 17 00:00:00 2001 From: Mandana Vaziri Date: Tue, 29 Oct 2024 10:17:59 -0400 Subject: [PATCH 2/5] bug fixes Signed-off-by: Mandana Vaziri --- examples/hello/hello_model_raw.pdl | 17 ----------------- src/pdl/pdl_interpreter.py | 15 +++++++++------ src/pdl/pdl_llms.py | 16 +++++++++++----- 3 files changed, 20 insertions(+), 28 deletions(-) delete mode 100644 examples/hello/hello_model_raw.pdl diff --git a/examples/hello/hello_model_raw.pdl b/examples/hello/hello_model_raw.pdl deleted file mode 100644 index 911d3f69f..000000000 --- a/examples/hello/hello_model_raw.pdl +++ /dev/null @@ -1,17 +0,0 @@ -description: Hello world -text: -- text: "Hello\n" - contribute: [context] -- model: ibm/granite-8b-code-instruct - platform: bam - def: output - modelResponse: raw_output - parameters: - decoding_method: greedy - stop_sequences: ["!"] - include_stop_sequence: true - return_options: - generated_tokens: True - token_logprobs: True - contribute: [] -- ${ raw_output } \ No newline at end of file diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index 521665a38..f9211e514 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -71,13 +71,13 @@ from .pdl_parser import PDLParseError, parse_file from .pdl_scheduler import ( CodeYieldResultMessage, + GeneratorWrapper, ModelCallMessage, ModelYieldResultMessage, YieldBackgroundMessage, YieldMessage, YieldResultMessage, schedule, - GeneratorWrapper, ) from .pdl_schema_validator import type_check_args, type_check_spec from .pdl_utils import messages_concat, messages_to_str, stringify @@ -1059,7 +1059,9 @@ def get_transformed_inputs(kwargs): litellm.input_callback = [get_transformed_inputs] # append_log(state, "Model Input", messages_to_str(model_input)) - msg, raw_result = yield from generate_client_response(state, concrete_block, model_input) + msg, raw_result = yield from generate_client_response( + state, concrete_block, model_input + ) if "input" in litellm_params: append_log(state, "Model Input", litellm_params["input"]) else: @@ -1087,13 +1089,14 @@ def generate_client_response( # pylint: disable=too-many-arguments block: BamModelBlock | LitellmModelBlock, model_input: Messages, ) -> Generator[YieldMessage, Any, Message]: + raw_result = None match state.batch: case 0: model_output, raw_result = yield from generate_client_response_streaming( state, block, model_input ) case 1: - model_output = yield from generate_client_response_single( + model_output, raw_result = yield from generate_client_response_single( state, block, model_input ) case _: @@ -1171,7 +1174,7 @@ def generate_client_response_single( model_input_str = messages_to_str(block.model, model_input) match block: case BamModelBlock(): - msg = BamModel.generate_text( + msg, raw_result = BamModel.generate_text( model_id=block.model, prompt_id=block.prompt_id, model_input=model_input_str, @@ -1180,7 +1183,7 @@ def generate_client_response_single( data=block.data, ) case LitellmModelBlock(): - msg = LitellmModel.generate_text( + msg, raw_result = LitellmModel.generate_text( model_id=block.model, messages=model_input, parameters=litellm_parameters_to_dict(block.parameters), @@ -1189,7 +1192,7 @@ def generate_client_response_single( yield 
YieldResultMessage(msg["content"]) if state.yield_background: yield YieldBackgroundMessage([msg]) - return msg + return msg, raw_result def generate_client_response_batching( # pylint: disable=too-many-arguments diff --git a/src/pdl/pdl_llms.py b/src/pdl/pdl_llms.py index 2992af6de..af9e89774 100644 --- a/src/pdl/pdl_llms.py +++ b/src/pdl/pdl_llms.py @@ -1,7 +1,7 @@ +import json from typing import Any, Generator, Optional import litellm -import json from dotenv import load_dotenv from genai.client import Client as BamClient from genai.credentials import Credentials as BamCredentials @@ -51,10 +51,11 @@ def generate_text( # pylint: disable=too-many-arguments,too-many-positional-arg parameters: Optional[dict | BamTextGenerationParameters], moderations: Optional[BamModerationParameters], data: Optional[BamPromptTemplateData], - ) -> Message: + ) -> tuple[Message, list[Any]]: client = BamModel.get_model() params = set_default_model_params(parameters) text = "" + responses = [] for response in client.text.generation.create( model_id=model_id, prompt_id=prompt_id, @@ -64,10 +65,11 @@ def generate_text( # pylint: disable=too-many-arguments,too-many-positional-arg data=data, ): # XXX TODO: moderation + responses.append(response) for result in response.results: if result.generated_text: text += result.generated_text - return {"role": None, "content": text} + return {"role": None, "content": text}, responses @staticmethod def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positional-arguments @@ -101,6 +103,7 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio if result.generated_text: yield {"role": None, "content": result.generated_text} return responses + # @staticmethod # def generate_text_lazy( # pylint: disable=too-many-arguments # model_id: str, @@ -157,14 +160,14 @@ def generate_text( msg = response.choices[0].message # pyright: ignore if msg.content is None: assert False, "TODO" # XXX TODO XXX - return {"role": msg.role, "content": msg.content} + return {"role": msg.role, "content": msg.content}, response.json() @staticmethod def generate_text_stream( model_id: str, messages: list[Message], parameters: dict[str, Any], - ) -> Generator[Message, Any, None]: + ) -> Generator[Message, Any, list[Any]]: if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id: parameters = set_default_granite_model_parameters(model_id, parameters) response = completion( @@ -173,8 +176,11 @@ def generate_text_stream( stream=True, **parameters, ) + result = [] for chunk in response: + result.append(chunk.json()) msg = chunk.choices[0].delta # pyright: ignore if msg.content is None: break yield {"role": msg.role, "content": msg.content} + return result From 45561c6c94763b87a9cf53c6ad9ab164563bfc96 Mon Sep 17 00:00:00 2001 From: Mandana Vaziri Date: Tue, 29 Oct 2024 10:25:25 -0400 Subject: [PATCH 3/5] cleanup Signed-off-by: Mandana Vaziri --- src/pdl/pdl_interpreter.py | 8 ++++---- src/pdl/pdl_llms.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index f9211e514..31b61be7f 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -1088,7 +1088,7 @@ def generate_client_response( # pylint: disable=too-many-arguments state: InterpreterState, block: BamModelBlock | LitellmModelBlock, model_input: Messages, -) -> Generator[YieldMessage, Any, Message]: +) -> Generator[YieldMessage, Any, tuple[Message, Any]]: raw_result = None match state.batch: case 
0: @@ -1111,7 +1111,7 @@ def generate_client_response_streaming( block: BamModelBlock | LitellmModelBlock, model_input: Messages, ) -> Generator[YieldMessage, Any, tuple[Message, Any]]: - msg_stream: Generator[Message, Any, None] + msg_stream: Generator[Message, Any, Any] model_input_str = messages_to_str(block.model, model_input) match block: case BamModelBlock(): @@ -1150,7 +1150,7 @@ def generate_client_response_streaming( if block.modelResponse is not None: raw_result = wrapped_gen.value if complete_msg is None: - return Message(role=state.role, content="") + return Message(role=state.role, content=""), raw_result return complete_msg, raw_result @@ -1169,7 +1169,7 @@ def generate_client_response_single( state: InterpreterState, block: BamModelBlock | LitellmModelBlock, model_input: Messages, -) -> Generator[YieldMessage, Any, Message]: +) -> Generator[YieldMessage, Any, tuple[Message, Any]]: msg: Message model_input_str = messages_to_str(block.model, model_input) match block: diff --git a/src/pdl/pdl_llms.py b/src/pdl/pdl_llms.py index af9e89774..9050fcd64 100644 --- a/src/pdl/pdl_llms.py +++ b/src/pdl/pdl_llms.py @@ -51,7 +51,7 @@ def generate_text( # pylint: disable=too-many-arguments,too-many-positional-arg parameters: Optional[dict | BamTextGenerationParameters], moderations: Optional[BamModerationParameters], data: Optional[BamPromptTemplateData], - ) -> tuple[Message, list[Any]]: + ) -> tuple[Message, Any]: client = BamModel.get_model() params = set_default_model_params(parameters) text = "" @@ -79,7 +79,7 @@ def generate_text_stream( # pylint: disable=too-many-arguments,too-many-positio parameters: Optional[dict | BamTextGenerationParameters], moderations: Optional[BamModerationParameters], data: Optional[BamPromptTemplateData], - ) -> Generator[Message, Any, list[Any]]: + ) -> Generator[Message, Any, Any]: client = BamModel.get_model() params = set_default_model_params(parameters) responses = [] @@ -149,7 +149,7 @@ def generate_text( model_id: str, messages: list[Message], parameters: dict[str, Any], - ) -> Message: + ) -> tuple[Message, Any]: if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id: parameters = set_default_granite_model_parameters(model_id, parameters) if parameters.get("mock_response") is not None: @@ -167,7 +167,7 @@ def generate_text_stream( model_id: str, messages: list[Message], parameters: dict[str, Any], - ) -> Generator[Message, Any, list[Any]]: + ) -> Generator[Message, Any, Any]: if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id: parameters = set_default_granite_model_parameters(model_id, parameters) response = completion( From c69113a761ef6c4cb5af893fe056a1b815e3f1d3 Mon Sep 17 00:00:00 2001 From: Mandana Vaziri Date: Tue, 29 Oct 2024 10:40:19 -0400 Subject: [PATCH 4/5] cleanup Signed-off-by: Mandana Vaziri --- src/pdl/pdl_llms.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/pdl/pdl_llms.py b/src/pdl/pdl_llms.py index 9050fcd64..c2780207c 100644 --- a/src/pdl/pdl_llms.py +++ b/src/pdl/pdl_llms.py @@ -160,7 +160,10 @@ def generate_text( msg = response.choices[0].message # pyright: ignore if msg.content is None: assert False, "TODO" # XXX TODO XXX - return {"role": msg.role, "content": msg.content}, response.json() + return { + "role": msg.role, + "content": msg.content, + }, response.json() # pyright: ignore @staticmethod def generate_text_stream( @@ -178,7 +181,7 @@ def generate_text_stream( ) result = [] for chunk in response: - result.append(chunk.json()) + 
result.append(chunk.json()) # pyright: ignore msg = chunk.choices[0].delta # pyright: ignore if msg.content is None: break From 8c1d2cb6f1f0b19a95ca6d1000a69462b4688b47 Mon Sep 17 00:00:00 2001 From: Mandana Vaziri Date: Tue, 29 Oct 2024 12:53:59 -0400 Subject: [PATCH 5/5] update to live viewer Signed-off-by: Mandana Vaziri --- pdl-live/src/pdl_ast.d.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pdl-live/src/pdl_ast.d.ts b/pdl-live/src/pdl_ast.d.ts index 6724d767d..a7080591b 100644 --- a/pdl-live/src/pdl_ast.d.ts +++ b/pdl-live/src/pdl_ast.d.ts @@ -2529,6 +2529,7 @@ export type Trace4 = | ErrorBlock | EmptyBlock | null; +export type Modelresponse = string | null; export type Platform = "bam"; export type PromptId = string | null; export type Parameters = @@ -2738,6 +2739,7 @@ export type Trace5 = | ErrorBlock | EmptyBlock | null; +export type Modelresponse1 = string | null; export type Platform1 = "litellm"; export type Parameters1 = | LitellmParameters @@ -3319,6 +3321,7 @@ export interface LitellmModelBlock { model: unknown; input?: Input1; trace?: Trace5; + modelResponse?: Modelresponse1; platform?: Platform1; parameters?: Parameters1; } @@ -3401,6 +3404,7 @@ export interface BamModelBlock { model: unknown; input?: Input; trace?: Trace4; + modelResponse?: Modelresponse; platform: Platform; prompt_id?: PromptId; parameters?: Parameters;
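
Note on where the raw model output comes from in the streaming path: generate_client_response_streaming wraps the chunk stream in GeneratorWrapper (imported from pdl_scheduler; its implementation is not part of this series) and reads wrapped_gen.value after the loop to obtain the list that generate_text_stream returns. The sketch below is only an assumed minimal version of that pattern, inferred from the call sites in this diff; the toy stream() generator and its token strings are invented purely for illustration.

    from typing import Generator, Generic, TypeVar

    YieldT = TypeVar("YieldT")
    SendT = TypeVar("SendT")
    ReturnT = TypeVar("ReturnT")

    class GeneratorWrapper(Generic[YieldT, SendT, ReturnT]):
        """Iterate a generator while capturing its return value in .value."""

        def __init__(self, gen: Generator[YieldT, SendT, ReturnT]):
            self.gen = gen
            self.value = None  # set to the generator's return value once exhausted

        def __iter__(self):
            # yield from re-yields every chunk and evaluates to the wrapped
            # generator's return value (e.g. the responses list returned by
            # generate_text_stream), which is stashed on the wrapper.
            self.value = yield from self.gen

    # Self-contained usage mirroring the interpreter's streaming loop:
    def stream() -> Generator[str, None, list[str]]:
        raw = []
        for tok in ("Hello", " world"):
            raw.append(tok)
            yield tok
        return raw  # this is what .value ends up holding

    wrapped = GeneratorWrapper(stream())
    for chunk in wrapped:
        print(chunk, end="")
    print()
    print("raw:", wrapped.value)  # raw: ['Hello', ' world']

Once the for-loop drains the stream, wrapped.value holds the accumulated raw responses, which is what the interpreter binds into scope under the modelResponse name when that field is set.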