Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/mcp_agent/workflows/deep_orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,13 @@ async def _create_full_plan(self) -> Plan:
planning_usage = planning_node.aggregate_usage()
self.budget.update_tokens(planning_usage.total_tokens)

# Sanitize obvious issues (invalid servers/agents/dependencies) before verification
try:
plan = self._sanitize_plan(plan)
except Exception as _:
# Non-fatal; let verifier surface errors
pass

# Verify the plan
verification_result = self.plan_verifier.verify_plan(plan)

Expand Down Expand Up @@ -552,6 +559,54 @@ async def _create_full_plan(self) -> Plan:
# Should not reach here
raise RuntimeError("Failed to create a valid plan")

def _sanitize_plan(self, plan: Plan) -> Plan:
    """Best-effort, in-place clean-up of a plan before it is verified.

    Steps are walked in order and every task in each step is adjusted:

    - ``agent``: a value that is not ``None`` and is not a key of
      ``self.agents`` is reset to ``None``.
    - ``servers``: a non-empty list is reduced to the entries found in
      ``self.available_servers``; duplicates are removed and first-seen
      order is kept.
    - ``requires_context_from``: a non-empty list is reduced to names of
      tasks that belong to an *earlier* step. Unknown names, and names of
      tasks in the same or a later step, are dropped.

    Empty or ``None`` ``servers`` / ``requires_context_from`` values are
    left exactly as they are.

    Args:
        plan: The plan to sanitize. Its tasks are mutated in place.

    Returns:
        The same ``plan`` object that was passed in.
    """
    allowed_servers = set(self.available_servers or [])
    allowed_agents = set(self.agents.keys()) if self.agents else set()

    # Names of tasks from steps that have been fully processed already.
    earlier_task_names: set[str] = set()

    for step in plan.steps:
        for task in step.tasks:
            # Unknown agent -> clear it.
            if task.agent is not None and task.agent not in allowed_agents:
                task.agent = None

            # Keep only allowed servers; dict.fromkeys de-duplicates while
            # preserving the order of first occurrence.
            if task.servers:
                task.servers = list(
                    dict.fromkeys(
                        server
                        for server in task.servers
                        if server in allowed_servers
                    )
                )

            # Dependencies may only point at tasks from previous steps.
            if task.requires_context_from:
                task.requires_context_from = [
                    dep
                    for dep in task.requires_context_from
                    if dep in earlier_task_names
                ]

        # Register this step's names only after all of its tasks were
        # checked, so that same-step references are rejected above.
        earlier_task_names.update(
            task.name for task in step.tasks if task.name
        )

    return plan

async def _verify_completion(self) -> tuple[bool, float]:
"""
Verify if the objective has been completed.
Expand Down
123 changes: 98 additions & 25 deletions src/mcp_agent/workflows/llm/augmented_llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(self, *args, **kwargs):
self.default_request_params = self.default_request_params or RequestParams(
model=default_model,
modelPreferences=self.model_preferences,
maxTokens=4096,
maxTokens=16384, # Use gpt-4o max output token value
systemPrompt=self.instruction,
parallel_tool_calls=False,
max_iterations=10,
Expand Down Expand Up @@ -247,6 +247,7 @@ async def generate(
# DEPRECATED: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens
# "max_tokens": params.maxTokens,
"max_completion_tokens": params.maxTokens,
# Use current OpenAI reasoning API shape
"reasoning_effort": self._reasoning_effort,
}
else:
Expand Down Expand Up @@ -419,13 +420,15 @@ async def generate_str(
final_text: List[str] = []

for response in responses:
content = response.content
if not content:
continue

if isinstance(content, str):
final_text.append(content)
continue
# Robust extraction of text content
try:
text = self.message_str(response, content_only=True)
if text:
final_text.append(text)
except Exception:
content = getattr(response, "content", None)
if isinstance(content, str):
final_text.append(content)

res = "\n".join(final_text)
span.set_attribute("response", res)
Expand Down Expand Up @@ -553,13 +556,36 @@ def _ensure_no_additional_props_and_require_all(node: dict):
if not completion.choices or completion.choices[0].message.content is None:
raise ValueError("No structured content returned by model")

content = completion.choices[0].message.content
# Extract JSON text from message content (string or list)
raw_content = completion.choices[0].message.content
json_text: str | None = None
if isinstance(raw_content, str):
json_text = raw_content
else:
try:
msg_dict = completion.choices[0].message.model_dump()
parts = msg_dict.get("content", [])
if isinstance(parts, list):
texts = []
for p in parts:
if isinstance(p, dict):
if "text" in p:
texts.append(p.get("text", ""))
elif "output_text" in p:
texts.append(p.get("output_text", ""))
if texts:
json_text = "".join(texts)
except Exception:
pass

if not json_text:
raise ValueError("Structured output missing textual JSON content")

try:
data = json.loads(content)
data = json.loads(json_text)
return response_model.model_validate(data)
except Exception:
# Fallback to pydantic JSON parsing if already a JSON string-like
return response_model.model_validate_json(content)
return response_model.model_validate_json(json_text)

async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest):
return request
Expand Down Expand Up @@ -643,13 +669,30 @@ def message_str(
self, message: ChatCompletionMessage, content_only: bool = False
) -> str:
"""Convert an output message to a string representation."""
# If simple string content
content = message.content
if content:
if isinstance(content, str):
return content
elif content_only:
# If content_only is True, return empty string if no content
return ""

# If content is a list of parts, join any available text fields
try:
msg_dict = message.model_dump()
parts = msg_dict.get("content", [])
if isinstance(parts, list):
texts = []
for p in parts:
if isinstance(p, dict):
if "text" in p:
texts.append(p.get("text", ""))
elif "output_text" in p:
texts.append(p.get("output_text", ""))
if texts:
return "".join(texts)
except Exception:
pass

if content_only:
return ""
return str(message)

def _annotate_span_for_generation_message(
Expand Down Expand Up @@ -992,6 +1035,11 @@ async def request_structured_completion_task(
"messages": [{"role": "user", "content": request.response_str}],
"response_format": response_format,
}
# Prefer reasoning API shape for reasoning-capable models
if request.model and str(request.model).startswith(
("o1", "o3", "o4", "gpt-5")
):
payload["reasoning_effort"] = "medium"
if request.user:
payload["user"] = request.user

Expand All @@ -1000,15 +1048,36 @@ async def request_structured_completion_task(
if not completion.choices or completion.choices[0].message.content is None:
raise ValueError("No structured content returned by model")

content = completion.choices[0].message.content
# message.content is expected to be JSON string
# Extract JSON from message content which may be a string or list of parts
raw_content = completion.choices[0].message.content
json_text: str | None = None
if isinstance(raw_content, str):
json_text = raw_content
else:
try:
msg_dict = completion.choices[0].message.model_dump()
parts = msg_dict.get("content", [])
if isinstance(parts, list):
texts = []
for p in parts:
if isinstance(p, dict):
if "text" in p:
texts.append(p.get("text", ""))
elif "output_text" in p:
texts.append(p.get("output_text", ""))
if texts:
json_text = "".join(texts)
except Exception:
pass

if not json_text:
raise ValueError("No structured JSON content returned by model")

try:
data = json.loads(content)
data = json.loads(json_text)
return response_model.model_validate(data)
except Exception:
# Some models may already return a dict-like; fall back to string validation
return response_model.model_validate_json(content)

return response_model.model_validate(data)
return response_model.model_validate_json(json_text)


class MCPOpenAITypeConverter(
Expand Down Expand Up @@ -1168,11 +1237,15 @@ def openai_content_to_mcp_content(
# TODO: saqadri - this is a best effort conversion, we should handle all possible content types
for c in content:
if (
c["type"] == "text"
c["type"] == "text" or c["type"] == "output_text"
): # isinstance(c, ChatCompletionContentPartTextParam):
mcp_content.append(
TextContent(
type="text", text=c["text"], **typed_dict_extras(c, ["text"])
type="text",
text=c.get("text") or c.get("output_text") or "",
**typed_dict_extras(c, ["text"])
if "text" in c
else typed_dict_extras(c, ["output_text"]),
)
)
elif (
Expand Down
Loading