Skip to content

Commit 2cc8889

Browse files
authored
Fix for llm message post-processing (#1106)
* fix for llm message post-processing Signed-off-by: Mandana Vaziri <[email protected]>
1 parent f807142 commit 2cc8889

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

src/pdl/pdl_llms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from .pdl_lazy import PdlConst, PdlLazy, lazy_apply
1919
from .pdl_schema_utils import pdltype_to_jsonschema
20-
from .pdl_utils import remove_none_values_from_message
20+
from .pdl_utils import message_post_processing
2121

2222
# Load environment variables
2323
load_dotenv()
@@ -47,7 +47,7 @@ async def async_generate_text(
4747
if msg.role is None:
4848
msg.role = "assistant"
4949
return (
50-
remove_none_values_from_message(msg.json()),
50+
message_post_processing(msg.json()),
5151
response.json(), # pyright: ignore
5252
)
5353
except httpx.RequestError as exc:
@@ -161,7 +161,7 @@ def generate_text_stream(
161161
msg = chunk.choices[0].delta # pyright: ignore
162162
if msg.role is None:
163163
msg.role = "assistant"
164-
yield remove_none_values_from_message(msg.model_dump())
164+
yield message_post_processing(msg.model_dump())
165165
return result
166166

167167

src/pdl/pdl_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,23 @@ def get_contribute_value(
108108
return None
109109

110110

111-
def remove_none_values_from_message(message: dict) -> dict[str, Any]:
111+
def message_post_processing(message: dict) -> dict[str, Any]:
112112
ret = {}
113113
for key, value in message.items():
114-
if key == "content":
114+
if key == "content" and value is not None:
115+
ret[key] = value
116+
elif (
117+
key == "reasoning_content" and value is not None
118+
): # TODO: replacing reasoning_content with content here
119+
key = "content"
115120
ret[key] = value
116121
if value is not None:
117122
if isinstance(value, dict):
118-
ret[key] = remove_none_values_from_message(value)
123+
ret[key] = message_post_processing(value)
119124
else:
120125
ret[key] = value
126+
if "content" not in ret:
127+
ret["content"] = ""
121128
return ret
122129

123130

0 commit comments

Comments
 (0)