Commit 82b0b8a

fix bug around system messages being discarded (#102)

mirodrr2 and mirodrr authored
Co-authored-by: michael rodriguez <[email protected]>
1 parent 6bab758, commit 82b0b8a

2 files changed: +34 -43 lines changed
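
For context, the bug this commit fixes: the old proxy code copied only the last "user" message from the request into chat_history, so a "system" message in the request body was silently dropped before the request reached the model. The sketch below contrasts the old and new merge behavior; it is illustrative only (the helper names are not part of the repo):

# Minimal sketch of the behavior change (illustrative; not the repo's code).

def merge_old(chat_history, incoming):
    # Old behavior: keep only the last "user" message, discarding system messages.
    user_msgs = [m for m in incoming if m["role"] == "user"]
    if user_msgs:
        chat_history.append(user_msgs[-1])
    return chat_history

def merge_new(chat_history, incoming):
    # New behavior: keep system and user messages in their original order.
    for msg in incoming:
        if msg["role"] in ("system", "user"):
            chat_history.append(msg)
    return chat_history

incoming = [
    {"role": "system", "content": "You are a master storyteller"},
    {"role": "user", "content": "Tell me a short story"},
]
print(merge_old([], list(incoming)))  # only the user message survives
print(merge_new([], list(incoming)))  # system message preserved, then user message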

middleware/app.py

Lines changed: 23 additions & 33 deletions
@@ -890,6 +890,7 @@ async def proxy_request(request: Request):
         )
     provided_hash = hash_api_key(api_key)

+    # Prepare or load chat_history
     if history_enabled:
         if session_id is not None:
             # Retrieve or verify existing session
@@ -902,44 +903,33 @@ async def proxy_request(request: Request):
                             "error": "Unauthorized: API key does not match session owner"
                         },
                     )
-                chat_history = (
-                    session_data["chat_history"]
-                    if session_data["chat_history"]
-                    else []
-                )
+                chat_history = session_data["chat_history"] or []
             else:
                 chat_history = []
                 create_chat_history(session_id, chat_history, provided_hash)
         else:
-            # No session_id provided but enable_history = True, so create a new session
+            # No session_id but enable_history = True, so create a new session
             session_id = str(uuid.uuid4())
             chat_history = []
             create_chat_history(session_id, chat_history, provided_hash)
-
-            # Merge incoming user messages into chat history
-            user_messages_this_round = [
-                m for m in data.get("messages", []) if m["role"] == "user"
-            ]
-            if user_messages_this_round:
-                chat_history.append(user_messages_this_round[-1])
-
-            # Overwrite data["messages"] with chat_history for the LLM request
-            data["messages"] = chat_history
     else:
-        # History is disabled and no valid session_id is provided.
-        # Pass messages through as-is.
+        # History not enabled: start with empty
         chat_history = []

-    # Merge incoming user messages into chat history
-    user_messages_this_round = [
-        m for m in data.get("messages", []) if m["role"] == "user"
-    ]
-    if user_messages_this_round:
-        chat_history.append(user_messages_this_round[-1])
+    # Merge incoming system/user messages into chat_history in original order
+    # (We generally skip adding "assistant" messages from the request side,
+    # because those come from the model, not from the user.)
+    new_messages = data.get("messages", [])
+    for msg in new_messages:
+        if msg["role"] in ["system", "user"]:
+            chat_history.append(msg)

+    # Now data["messages"] should be the entire conversation the model sees
     data["messages"] = chat_history

-    # Check for prompt ARN logic
+    # ---------------------------------------------------------------------
+    # Handle optional "Bedrock Prompt" logic (unchanged from your snippet):
+    # ---------------------------------------------------------------------
     model_id = data.get("model")
     prompt_variables = data.pop("promptVariables", {})
     final_prompt_text = None
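
For readers skimming the hunk above, the history setup now resolves chat_history one of three ways: load it from an existing session, create it for a new session, or leave it empty when history is disabled. Below is a rough standalone sketch of that branching; the in-memory store and stubbed get_chat_history / create_chat_history helpers are assumptions for illustration, not the middleware's actual storage layer:

import uuid

# Hypothetical in-memory stand-in for the real session store.
_SESSIONS = {}

def get_chat_history(session_id):
    return _SESSIONS.get(session_id)

def create_chat_history(session_id, chat_history, api_key_hash):
    _SESSIONS[session_id] = {"chat_history": chat_history, "api_key_hash": api_key_hash}

def resolve_history(history_enabled, session_id, provided_hash):
    """Mirrors the branching above: load, create, or skip history."""
    if history_enabled:
        if session_id is not None:
            session_data = get_chat_history(session_id)
            if session_data:
                if session_data["api_key_hash"] != provided_hash:
                    raise PermissionError("API key does not match session owner")
                chat_history = session_data["chat_history"] or []
            else:
                chat_history = []
                create_chat_history(session_id, chat_history, provided_hash)
        else:
            session_id = str(uuid.uuid4())
            chat_history = []
            create_chat_history(session_id, chat_history, provided_hash)
    else:
        chat_history = []
    return session_id, chat_history
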
@@ -968,15 +958,14 @@ async def proxy_request(request: Request):
     if final_prompt_text:
         data["messages"] = [{"role": "user", "content": final_prompt_text}]

-    # client = AsyncOpenAI(api_key=api_key, base_url=LITELLM_ENDPOINT)
-
+    # ---------------------------------------------------------------------
+    # Stream vs. Non-Stream logic
+    # ---------------------------------------------------------------------
     if is_streaming:
-        # print(f"streaming")
         return await get_chat_stream(
             api_key, data, session_id, chat_history, history_enabled
         )
     else:
-        # print(f"not streaming")
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {api_key}",
@@ -985,14 +974,14 @@ async def proxy_request(request: Request):
            async with session.post(
                f"{LITELLM_ENDPOINT}/v1/chat/completions",
                headers=headers,
-                json=data,  # Sending the data in the body
+                json=data,
            ) as resp:
-                # Parse the response JSON
                response_headers = dict(resp.headers)
-                response_headers.pop("Content-Length")
-                # print(response_headers)
+                # Avoid passing through invalid content-length
+                response_headers.pop("Content-Length", None)
                response_dict = await resp.json()

+                # If there's a response from the assistant, save it to history
                if response_dict.get("choices"):
                    assistant_message = response_dict["choices"][0]["message"]
                    if history_enabled:
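
One small but load-bearing detail in the hunk above: dict.pop with a default never raises, so the proxy no longer fails when the upstream response carries no Content-Length header (common with chunked or compressed responses). A quick illustration:

headers = {"Content-Type": "application/json"}  # no Content-Length present

headers.pop("Content-Length", None)   # safe: returns None, leaves the dict unchanged
# headers.pop("Content-Length")       # old form: raises KeyError when the header is absent
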
@@ -1001,6 +990,7 @@ async def proxy_request(request: Request):
                        )
                        update_chat_history(session_id, chat_history)

+                # Return session_id in the response if we have one
                if session_id:
                    response_dict["session_id"] = session_id


tests/openai_chat_test_file.py

Lines changed: 11 additions & 10 deletions
@@ -11,6 +11,7 @@
 base_url = os.getenv("API_ENDPOINT")
 api_key = os.getenv("API_KEY")
 model_id = os.getenv("MODEL_ID")
+print(f'base_url: {base_url} api_key: {api_key} model_id: {model_id}')
 client = OpenAI(base_url=base_url, api_key=api_key)
 managed_prompt_arn = os.getenv("MANAGED_PROMPT_ARN")
 managed_prompt_variable_name = os.getenv("MANAGED_PROMPT_VARIABLE_NAME")
@@ -23,7 +24,7 @@

 async def stream_completion(
     prompt: str,
-    model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
+    model: str = "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
     extra_body: Dict[str, Any] = None,
 ) -> AsyncGenerator[Tuple[str, str], None]:
     """
@@ -55,8 +56,8 @@ async def stream_completion(


 def get_completion(
-    prompt: str,
-    model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
+    messages: list,
+    model: str = "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
     extra_body: Dict[str, Any] = None,
 ) -> Tuple[str, str]:
     """
@@ -68,7 +69,7 @@ def get_completion(

     response = client.chat.completions.create(
         model=model,
-        messages=[{"role": "user", "content": prompt}],
+        messages=messages,
         stream=False,
         extra_body=extra_body,
     )
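
With this signature change, callers build the messages list themselves, while proxy-specific options such as enable_history and session_id still travel in extra_body, which the OpenAI client merges into the request body. A usage sketch (small_prompt and model_id come from the test module's environment setup, so this is call shape only, not a standalone script):

content, session_id = get_completion(
    [
        {"role": "system", "content": "You are a master storyteller"},
        {"role": "user", "content": small_prompt},
    ],
    model_id,
    extra_body={"enable_history": True},
)
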
@@ -79,7 +80,7 @@ def get_completion(


 def test_openai_chat():
-    content, session_id = get_completion(small_prompt)
+    content, session_id = get_completion([{"role": "user", "content": small_prompt}])
     assert content is not None and content.strip()
     assert session_id is not None and session_id.strip()
     print(f"test_openai_chat response content: {content} session_id: {session_id}")
@@ -106,15 +107,15 @@ async def test_openai_chat_streaming():

 def test_openai_chat_history():
     print("First request:", flush=True)
-    response_content_1, session_id_1 = get_completion(small_prompt)
+    response_content_1, session_id_1 = get_completion([{"role": "system", "content": "You are a master storyteller"},{"role": "user", "content": small_prompt}], model_id, extra_body={"enable_history": True})
     assert response_content_1 is not None and response_content_1.strip()
     assert session_id_1 is not None and session_id_1.strip()
     print(f"Content: {response_content_1}")
     print(f"Session ID: {session_id_1}\n")

     print("\nSecond request (with session_id):", flush=True)
     response_content_2, session_id_2 = get_completion(
-        small_prompt_follow_up, extra_body={"session_id": session_id_1}
+        [{"role": "user", "content": small_prompt_follow_up}], model_id, extra_body={"session_id": session_id_1}
     )
     print(f"Content: {response_content_2}")
     print(f"Session ID: {session_id_2}\n")
@@ -171,7 +172,7 @@ def test_bedrock_managed_prompt():

     # Test with a managed prompt
     response_content, session_id = get_completion(
-        "",  # Empty prompt as it won't be used
+        [{"role": "user", "content": ""}],  # Empty prompt as it won't be used
         model=managed_prompt_arn,
         extra_body={
             "promptVariables": {
@@ -220,7 +221,7 @@ async def test_bedrock_managed_prompt_streaming():


 def test_large_prompt():
-    content, session_id = get_completion(large_prompt)
+    content, session_id = get_completion([{"role": "user", "content": large_prompt}])
     assert content is not None and content.strip()
     assert session_id is not None and session_id.strip()
     print(f"test_openai_chat response content: {content} session_id: {session_id}")
@@ -238,7 +239,7 @@ def test_invalid_api_key():
     # Attempt to make a request with the invalid client
     with pytest.raises(OpenAIError) as exc_info:
         response = invalid_client.chat.completions.create(
-            model="anthropic.claude-3-5-sonnet-20241022-v2:0",
+            model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
             messages=[{"role": "user", "content": small_prompt}],
             stream=False,
         )
