Skip to content

Commit 0213ea7

Browse files
gltanaka, Serhan-Asad, and claude
authored
fix: deep copy formatted_messages to prevent Groq message mutation (#562) (#598)
* fix: deep copy formatted_messages to prevent Groq structured output mutation corrupting fallback models (#562)

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* test: trim Groq mutation tests to 3 focused regressions in test_llm_invoke.py (#562)

  Replace 770 lines across 2 standalone test files with 3 focused regression tests (~80 lines) in the existing test_llm_invoke.py, per PR #591 review.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Serhan <serhanasad2013@live.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4567570 commit 0213ea7

File tree

2 files changed

+124
-2
lines changed

2 files changed

+124
-2
lines changed

pdd/llm_invoke.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Corrected code_under_test (llm_invoke.py)
22
# Added optional debugging prints in _select_model_candidates
33

4+
import copy
45
import os
56
import pandas as pd
67
import litellm
@@ -1971,7 +1972,7 @@ def calc_strength(candidate):
19711972
# --- 5. Prepare LiteLLM Arguments ---
19721973
litellm_kwargs: Dict[str, Any] = {
19731974
"model": model_name_litellm,
1974-
"messages": formatted_messages,
1975+
"messages": copy.deepcopy(formatted_messages),
19751976
# Use a local adjustable temperature to allow provider-specific fallbacks
19761977
"temperature": current_temperature,
19771978
# Retry on transient network errors (APIError, TimeoutError, ServiceUnavailableError)

tests/test_llm_invoke.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Corrected unit_test (tests/test_llm_invoke.py)
22

3+
import copy
34
import pytest
45
import os
56
import pandas as pd
@@ -4863,4 +4864,124 @@ def capture_completion(**kwargs):
48634864

48644865
# time=None should be treated as 0, so no reasoning params
48654866
assert "thinking" not in captured_kwargs
4866-
assert "reasoning_effort" not in captured_kwargs
4867+
assert "reasoning_effort" not in captured_kwargs
4868+
4869+
4870+
# ---------------------------------------------------------------------------
4871+
# Issue #562: Groq structured output mutation regression tests
4872+
# ---------------------------------------------------------------------------
4873+
4874+
class _GroqMutationFixture:
4875+
"""Shared helpers for Groq message mutation tests."""
4876+
4877+
class SimpleResult(BaseModel):
4878+
answer: str
4879+
confidence: float
4880+
4881+
GROQ_SCHEMA_MARKER = "You must respond with valid JSON matching this schema"
4882+
4883+
@staticmethod
4884+
def _model(provider, model, elo, api_key):
4885+
return {
4886+
"provider": provider, "model": model, "input": 0.15,
4887+
"output": 0.60, "coding_arena_elo": elo,
4888+
"structured_output": True, "base_url": "", "api_key": api_key,
4889+
"max_tokens": "", "max_completion_tokens": "",
4890+
"reasoning_type": "none", "max_reasoning_tokens": 0,
4891+
}
4892+
4893+
@staticmethod
4894+
def _make_response(content):
4895+
resp = MagicMock()
4896+
choice = MagicMock()
4897+
choice.message.content = content
4898+
choice.finish_reason = "stop"
4899+
resp.choices = [choice]
4900+
usage = MagicMock()
4901+
usage.prompt_tokens = 100
4902+
usage.completion_tokens = 50
4903+
resp.usage = usage
4904+
return resp
4905+
4906+
4907+
class TestGroqMessageMutation(_GroqMutationFixture):
4908+
"""Regression tests for #562: Groq structured output must not corrupt
4909+
formatted_messages shared across fallback candidates."""
4910+
4911+
def _run(self, input_messages, candidates):
4912+
"""Call llm_invoke with Groq failing, capture every litellm.completion call."""
4913+
import pdd.llm_invoke as _llm_mod
4914+
4915+
groq = self._model("Groq", "groq/llama-3.3-70b-versatile", 1200, "GROQ_API_KEY")
4916+
openai = self._model("OpenAI", "gpt-4o-mini", 1100, "OPENAI_API_KEY")
4917+
models = [self._model(*c) for c in candidates] if candidates else [groq, openai]
4918+
4919+
captured = []
4920+
4921+
def side_effect(**kwargs):
4922+
captured.append(copy.deepcopy(kwargs))
4923+
if "groq/" in kwargs.get("model", ""):
4924+
raise Exception("Groq API error")
4925+
resp = self._make_response(json.dumps({"answer": "4", "confidence": 0.99}))
4926+
_llm_mod._LAST_CALLBACK_DATA["cost"] = 0.001
4927+
_llm_mod._LAST_CALLBACK_DATA["input_tokens"] = 100
4928+
_llm_mod._LAST_CALLBACK_DATA["output_tokens"] = 50
4929+
return resp
4930+
4931+
with patch.dict(os.environ, {"PDD_FORCE_LOCAL": "1",
4932+
"GROQ_API_KEY": "k", "OPENAI_API_KEY": "k"}), \
4933+
patch("pdd.llm_invoke._ensure_api_key", return_value=True), \
4934+
patch("pdd.llm_invoke._select_model_candidates", return_value=models), \
4935+
patch("pdd.llm_invoke._load_model_data", return_value=pd.DataFrame(models)), \
4936+
patch("pdd.llm_invoke.litellm") as mock_litellm:
4937+
mock_litellm.completion = MagicMock(side_effect=side_effect)
4938+
mock_litellm.cache = None
4939+
mock_litellm.drop_params = True
4940+
llm_invoke(
4941+
messages=input_messages, strength=0.5, temperature=0.0,
4942+
time=0.0, output_pydantic=self.SimpleResult, use_cloud=False,
4943+
)
4944+
return captured
4945+
4946+
def test_groq_fallback_no_schema_in_messages(self):
4947+
"""Core: after Groq fails, fallback gets clean messages (no JSON schema)."""
4948+
captured = self._run(
4949+
[{"role": "user", "content": "What is 2+2?"}],
4950+
[("Groq", "groq/llama-3.3-70b-versatile", 1200, "GROQ_API_KEY"),
4951+
("OpenAI", "gpt-4o-mini", 1100, "OPENAI_API_KEY")],
4952+
)
4953+
assert len(captured) >= 2
4954+
for msg in captured[1]["messages"]:
4955+
assert self.GROQ_SCHEMA_MARKER not in msg.get("content", ""), \
4956+
"Fallback model received Groq schema instruction"
4957+
4958+
def test_groq_fallback_system_message_not_mutated(self):
4959+
"""Dict overwrite: existing system message preserved for fallback."""
4960+
original = "You are a helpful math tutor."
4961+
captured = self._run(
4962+
[{"role": "system", "content": original},
4963+
{"role": "user", "content": "What is 2+2?"}],
4964+
[("Groq", "groq/llama-3.3-70b-versatile", 1200, "GROQ_API_KEY"),
4965+
("OpenAI", "gpt-4o-mini", 1100, "OPENAI_API_KEY")],
4966+
)
4967+
assert len(captured) >= 2
4968+
fallback_sys = captured[1]["messages"][0]
4969+
assert fallback_sys["content"] == original, \
4970+
f"System message mutated: {fallback_sys['content'][:200]}"
4971+
4972+
def test_groq_multiple_failures_no_cumulative_corruption(self):
4973+
"""Two Groq models fail; schema instructions must not accumulate."""
4974+
captured = self._run(
4975+
[{"role": "user", "content": "What is 2+2?"}],
4976+
[("Groq", "groq/llama-3.3-70b-versatile", 1200, "GROQ_API_KEY"),
4977+
("Groq", "groq/mixtral-8x7b-32768", 1150, "GROQ_API_KEY"),
4978+
("OpenAI", "gpt-4o-mini", 1100, "OPENAI_API_KEY")],
4979+
)
4980+
assert len(captured) >= 3
4981+
final = next(c for c in reversed(captured) if "groq/" not in c.get("model", ""))
4982+
schema_count = sum(
4983+
1 for m in final["messages"]
4984+
if self.GROQ_SCHEMA_MARKER in m.get("content", "")
4985+
)
4986+
assert schema_count == 0, \
4987+
f"Fallback got {schema_count} accumulated schema instructions"

0 commit comments

Comments (0)