Skip to content

Commit 4a4759e

Browse files
authored
remove None valued parameters when calling LLMs (#212)
* remove None valued parameters when calling LLMs Signed-off-by: Mandana Vaziri <[email protected]>
1 parent ce06098 commit 4a4759e

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

src/pdl/pdl_interpreter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,7 @@ def get_transformed_inputs(kwargs):
11251125

11261126
litellm.input_callback = [get_transformed_inputs]
11271127
# append_log(state, "Model Input", messages_to_str(model_input))
1128+
11281129
msg, raw_result = yield from generate_client_response(
11291130
state, concrete_block, model_input
11301131
)

src/pdl/pdl_llms.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
set_default_model_params,
1717
set_structured_decoding_parameters,
1818
)
19+
from .pdl_utils import remove_none_values_from_message
1920

2021
# Load environment variables
2122
load_dotenv()
@@ -165,7 +166,10 @@ def generate_text(
165166
msg = response.choices[0].message # pyright: ignore
166167
if msg.role is None:
167168
msg.role = "assistant"
168-
return msg.json(), response.json() # pyright: ignore
169+
return (
170+
remove_none_values_from_message(msg.json()),
171+
response.json(), # pyright: ignore
172+
)
169173

170174
@staticmethod
171175
def generate_text_stream(
@@ -191,5 +195,5 @@ def generate_text_stream(
191195
msg = chunk.choices[0].delta # pyright: ignore
192196
if msg.role is None:
193197
msg.role = "assistant"
194-
yield msg.model_dump()
198+
yield remove_none_values_from_message(msg.model_dump())
195199
return result

src/pdl/pdl_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Sequence
2+
from typing import Any, Sequence
33

44
from .pdl_ast import ContributeTarget, ContributeValue, FunctionBlock, Message, Messages
55

@@ -70,3 +70,16 @@ def simple_message(message: Message) -> bool:
7070
if message.keys() == {"role", "content"} and message["content"] is not None:
7171
return True
7272
return False
73+
74+
75+
def remove_none_values_from_message(message: Any) -> dict[str, Any]:
76+
ret = {}
77+
for key, value in message.items():
78+
if key == "content":
79+
ret[key] = value
80+
if value is not None:
81+
if isinstance(value, dict):
82+
ret[key] = remove_none_values_from_message(value)
83+
else:
84+
ret[key] = value
85+
return ret

0 commit comments

Comments
 (0)