Commit 45561c6

cleanup
Signed-off-by: Mandana Vaziri <[email protected]>
1 parent 801e2a6 commit 45561c6

File tree

2 files changed, +8 -8 lines changed


src/pdl/pdl_interpreter.py

Lines changed: 4 additions & 4 deletions
@@ -1088,7 +1088,7 @@ def generate_client_response(  # pylint: disable=too-many-arguments
     state: InterpreterState,
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
     raw_result = None
     match state.batch:
         case 0:

@@ -1111,7 +1111,7 @@ def generate_client_response_streaming(
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
 ) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
-    msg_stream: Generator[Message, Any, None]
+    msg_stream: Generator[Message, Any, Any]
     model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():

@@ -1150,7 +1150,7 @@ def generate_client_response_streaming(
     if block.modelResponse is not None:
         raw_result = wrapped_gen.value
     if complete_msg is None:
-        return Message(role=state.role, content="")
+        return Message(role=state.role, content=""), raw_result
     return complete_msg, raw_result


@@ -1169,7 +1169,7 @@ def generate_client_response_single(
     state: InterpreterState,
     block: BamModelBlock | LitellmModelBlock,
     model_input: Messages,
-) -> Generator[YieldMessage, Any, Message]:
+) -> Generator[YieldMessage, Any, tuple[Message, Any]]:
     msg: Message
     model_input_str = messages_to_str(block.model, model_input)
     match block:
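The interpreter-side change threads the provider's raw response out of the model-calling generators: each `generate_client_response*` generator now finishes with a `tuple[Message, Any]` instead of a bare `Message`, and the streaming path recovers the raw result from the wrapped stream's return value (`wrapped_gen.value`). A minimal sketch of that capture pattern follows; this `GeneratorWrapper` is an illustrative stand-in, not necessarily PDL's actual helper.

from typing import Any, Generator

# Illustrative stand-in (an assumption, not PDL's real helper): drain a
# generator while capturing the value it returns, the way wrapped_gen.value
# is read once streaming finishes in the diff above.
class GeneratorWrapper:
    def __init__(self, gen: Generator[Any, Any, Any]) -> None:
        self.gen = gen
        self.value: Any = None  # will hold the generator's return value

    def __iter__(self):
        # `yield from` forwards every yielded chunk and evaluates to the
        # inner generator's return value (StopIteration.value) when done.
        self.value = yield from self.gen

def stream() -> Generator[str, Any, dict]:
    yield "partial"
    yield "chunks"
    return {"raw": "full provider response"}  # surfaced via .value

wrapped = GeneratorWrapper(stream())
for chunk in wrapped:
    print(chunk)        # partial, chunks
print(wrapped.value)    # {'raw': 'full provider response'}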

src/pdl/pdl_llms.py

Lines changed: 4 additions & 4 deletions
@@ -51,7 +51,7 @@ def generate_text(  # pylint: disable=too-many-arguments,too-many-positional-arguments
     parameters: Optional[dict | BamTextGenerationParameters],
     moderations: Optional[BamModerationParameters],
     data: Optional[BamPromptTemplateData],
-) -> tuple[Message, list[Any]]:
+) -> tuple[Message, Any]:
     client = BamModel.get_model()
     params = set_default_model_params(parameters)
     text = ""

@@ -79,7 +79,7 @@ def generate_text_stream(  # pylint: disable=too-many-arguments,too-many-positional-arguments
     parameters: Optional[dict | BamTextGenerationParameters],
     moderations: Optional[BamModerationParameters],
     data: Optional[BamPromptTemplateData],
-) -> Generator[Message, Any, list[Any]]:
+) -> Generator[Message, Any, Any]:
     client = BamModel.get_model()
     params = set_default_model_params(parameters)
     responses = []

@@ -149,7 +149,7 @@ def generate_text(
     model_id: str,
     messages: list[Message],
     parameters: dict[str, Any],
-) -> Message:
+) -> tuple[Message, Any]:
     if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
         parameters = set_default_granite_model_parameters(model_id, parameters)
     if parameters.get("mock_response") is not None:

@@ -167,7 +167,7 @@ def generate_text_stream(
     model_id: str,
     messages: list[Message],
     parameters: dict[str, Any],
-) -> Generator[Message, Any, list[Any]]:
+) -> Generator[Message, Any, Any]:
     if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
         parameters = set_default_granite_model_parameters(model_id, parameters)
     response = completion(
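The backend signatures in pdl_llms.py now mirror that contract: non-streaming `generate_text` hands back the parsed message together with the raw provider response, and the streaming variants surface the raw response as the generator's return value. A hedged sketch of the new calling convention, with `Message` as a stand-in type and fake backends in place of the real BAM/LiteLLM calls:

from typing import Any, Generator

Message = dict[str, str]  # stand-in for PDL's Message type (assumption)

def generate_text(model_id: str, messages: list[Message],
                  parameters: dict[str, Any]) -> tuple[Message, Any]:
    # Fake backend: returns the parsed message plus the raw response.
    raw = {"model": model_id, "choices": ["..."]}
    return {"role": "assistant", "content": "hello"}, raw

def generate_text_stream(model_id: str, messages: list[Message],
                         parameters: dict[str, Any]) -> Generator[Message, Any, Any]:
    # Fake streaming backend: yields chunks, returns the raw response.
    yield {"role": "assistant", "content": "hel"}
    yield {"role": "assistant", "content": "lo"}
    return {"model": model_id, "usage": {"total_tokens": 2}}

msg, raw_result = generate_text("some-model", [], {})  # callers now unpack a pair

Since the raw result is only stored when a block sets modelResponse, callers that don't need it can simply discard it: msg, _ = generate_text(...).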
