Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/packages/core/agent_framework/_workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
GroupChatDirective,
GroupChatStateSnapshot,
ManagerDirectiveModel,
ManagerSelectionRequest,
ManagerSelectionResponse,
)
from ._handoff import HandoffBuilder, HandoffUserInputRequest
from ._magentic import (
Expand Down Expand Up @@ -143,6 +145,8 @@
"MagenticPlanReviewReply",
"MagenticPlanReviewRequest",
"ManagerDirectiveModel",
"ManagerSelectionRequest",
"ManagerSelectionResponse",
"Message",
"OrchestrationState",
"RequestInfoEvent",
Expand Down
2 changes: 2 additions & 0 deletions python/packages/core/agent_framework/_workflows/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ from ._group_chat import (
GroupChatBuilder,
GroupChatDirective,
GroupChatStateSnapshot,
ManagerSelectionResponse,
)
from ._handoff import HandoffBuilder, HandoffUserInputRequest
from ._magentic import (
Expand Down Expand Up @@ -139,6 +140,7 @@ __all__ = [
"MagenticPlanReviewDecision",
"MagenticPlanReviewReply",
"MagenticPlanReviewRequest",
"ManagerSelectionResponse",
"Message",
"OrchestrationState",
"RequestInfoEvent",
Expand Down
830 changes: 626 additions & 204 deletions python/packages/core/agent_framework/_workflows/_group_chat.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,7 @@ def _handoff_orchestrator_factory(_: _GroupChatConfig) -> Executor:

wiring = _GroupChatConfig(
manager=None,
manager_participant=None,
manager_name=self._starting_agent_id,
participants=participant_specs,
max_rounds=None,
Expand Down
51 changes: 28 additions & 23 deletions python/packages/core/agent_framework/_workflows/_magentic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ def register_agent_executor(self, name: str, executor: "MagenticAgentExecutor")

async def _emit_orchestrator_message(
self,
ctx: WorkflowContext[Any, ChatMessage],
ctx: WorkflowContext[Any, list[ChatMessage]],
message: ChatMessage,
kind: str,
) -> None:
Expand Down Expand Up @@ -1155,7 +1155,7 @@ async def handle_start_message(
self,
message: _MagenticStartMessage,
context: WorkflowContext[
_MagenticResponseMessage | _MagenticRequestMessage | _MagenticPlanReviewRequest, ChatMessage
_MagenticResponseMessage | _MagenticRequestMessage | _MagenticPlanReviewRequest, list[ChatMessage]
],
) -> None:
"""Handle the initial start message to begin orchestration."""
Expand Down Expand Up @@ -1190,7 +1190,7 @@ async def handle_start_message(

# Start the inner loop
ctx2 = cast(
WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, ChatMessage],
WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
context,
)
await self._run_inner_loop(ctx2)
Expand All @@ -1200,7 +1200,7 @@ async def handle_task_text(
self,
task_text: str,
context: WorkflowContext[
_MagenticResponseMessage | _MagenticRequestMessage | _MagenticPlanReviewRequest, ChatMessage
_MagenticResponseMessage | _MagenticRequestMessage | _MagenticPlanReviewRequest, list[ChatMessage]
],
) -> None:
await self.handle_start_message(_MagenticStartMessage.from_string(task_text), context)
Expand All @@ -1210,7 +1210,7 @@ async def handle_task_message(
self,
task_message: ChatMessage,
context: WorkflowContext[
_MagenticResponseMessage | _MagenticRequestMessage | _MagenticPlanReviewRequest, ChatMessage
_MagenticResponseMessage | _MagenticRequestMessage | _MagenticPlanReviewRequest, list[ChatMessage]
],
) -> None:
await self.handle_start_message(_MagenticStartMessage(task_message), context)
Expand All @@ -1220,7 +1220,7 @@ async def handle_task_messages(
self,
conversation: list[ChatMessage],
context: WorkflowContext[
_MagenticResponseMessage | _MagenticRequestMessage | _MagenticPlanReviewRequest, ChatMessage
_MagenticResponseMessage | _MagenticRequestMessage | _MagenticPlanReviewRequest, list[ChatMessage]
],
) -> None:
await self.handle_start_message(_MagenticStartMessage(conversation), context)
Expand All @@ -1229,7 +1229,7 @@ async def handle_task_messages(
async def handle_response_message(
self,
message: _MagenticResponseMessage,
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, ChatMessage],
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
) -> None:
"""Handle responses from agents."""
if getattr(self, "_terminated", False):
Expand Down Expand Up @@ -1261,7 +1261,7 @@ async def handle_plan_review_response(
response: _MagenticPlanReviewReply,
context: WorkflowContext[
# may broadcast ledger next, or ask for another round of review
_MagenticResponseMessage | _MagenticRequestMessage | _MagenticPlanReviewRequest, ChatMessage
_MagenticResponseMessage | _MagenticRequestMessage | _MagenticPlanReviewRequest, list[ChatMessage]
],
) -> None:
if getattr(self, "_terminated", False):
Expand Down Expand Up @@ -1307,7 +1307,7 @@ async def handle_plan_review_response(

# Enter the normal coordination loop
ctx2 = cast(
WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, ChatMessage],
WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
context,
)
await self._run_inner_loop(ctx2)
Expand All @@ -1334,7 +1334,7 @@ async def handle_plan_review_response(
self._context.chat_history.append(self._task_ledger)
# No further review requests; proceed directly into coordination
ctx2 = cast(
WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, ChatMessage],
WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
context,
)
await self._run_inner_loop(ctx2)
Expand Down Expand Up @@ -1369,7 +1369,7 @@ async def handle_plan_review_response(

async def _run_outer_loop(
self,
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, ChatMessage],
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
) -> None:
"""Run the outer orchestration loop - planning phase."""
if self._context is None:
Expand All @@ -1392,7 +1392,7 @@ async def _run_outer_loop(

async def _run_inner_loop(
self,
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, ChatMessage],
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
) -> None:
"""Run the inner orchestration loop. Coordination phase. Serialized with a lock."""
if self._context is None or self._task_ledger is None:
Expand All @@ -1402,7 +1402,7 @@ async def _run_inner_loop(

async def _run_inner_loop_helper(
self,
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, ChatMessage],
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
) -> None:
"""Run inner loop with exclusive access."""
# Narrow optional context for the remainder of this method
Expand Down Expand Up @@ -1487,7 +1487,7 @@ async def _run_inner_loop_helper(

async def _reset_and_replan(
self,
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, ChatMessage],
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
) -> None:
"""Reset context and replan."""
if self._context is None:
Expand All @@ -1513,7 +1513,7 @@ async def _reset_and_replan(

async def _prepare_final_answer(
self,
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, ChatMessage],
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
) -> None:
"""Prepare the final answer using the manager."""
if self._context is None:
Expand All @@ -1523,12 +1523,14 @@ async def _prepare_final_answer(
final_answer = await self._manager.prepare_final_answer(self._context.clone(deep=True))

# Emit a completed event for the workflow
await context.yield_output(final_answer)
# Yield the full conversation history including the final answer
conversation = [*list(self._context.chat_history), final_answer]
await context.yield_output(conversation)
await context.add_event(MagenticFinalResultEvent(message=final_answer))

async def _check_within_limits_or_complete(
self,
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, ChatMessage],
context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]],
) -> bool:
"""Check if orchestrator is within operational limits."""
if self._context is None:
Expand All @@ -1555,7 +1557,9 @@ async def _check_within_limits_or_complete(
)

# Yield the partial result and signal completion
await context.yield_output(partial_result)
# Yield the full conversation history including the partial result
conversation = [*list(ctx.chat_history), partial_result]
await context.yield_output(conversation)
await context.add_event(MagenticFinalResultEvent(message=partial_result))
return False

Expand Down Expand Up @@ -2352,21 +2356,22 @@ async def _validate_checkpoint_participants(
return

# At this point, checkpoint is guaranteed to be WorkflowCheckpoint
executor_states: dict[str, Any] = checkpoint.shared_state.get(EXECUTOR_STATE_KEY, {})
executor_states = cast(dict[str, Any], checkpoint.shared_state.get(EXECUTOR_STATE_KEY, {}))
orchestrator_id = getattr(orchestrator, "id", "")
orchestrator_state = executor_states.get(orchestrator_id)
orchestrator_state = cast(Any, executor_states.get(orchestrator_id))
if orchestrator_state is None:
orchestrator_state = executor_states.get("magentic_orchestrator")
orchestrator_state = cast(Any, executor_states.get("magentic_orchestrator"))

if not isinstance(orchestrator_state, dict):
return

context_payload = orchestrator_state.get("magentic_context")
orchestrator_state_dict = cast(dict[str, Any], orchestrator_state)
context_payload = cast(Any, orchestrator_state_dict.get("magentic_context"))
if not isinstance(context_payload, dict):
return

context_dict = cast(dict[str, Any], context_payload)
restored_participants = context_dict.get("participant_descriptions")
restored_participants = cast(Any, context_dict.get("participant_descriptions"))
if not isinstance(restored_participants, dict):
return

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ def is_registered(self, name: str) -> bool:
"""Check if a participant is registered."""
return name in self._participant_entry_ids

def is_participant_registered(self, name: str) -> bool:
"""Check if a participant is registered (alias for is_registered for compatibility)."""
return self.is_registered(name)

def all_participants(self) -> set[str]:
"""Get all registered participant names."""
return set(self._participant_entry_ids.keys())
Loading
Loading