python/packages/core/agent_framework/_workflows/_group_chat.py (121 changes: 102 additions & 19 deletions)
@@ -31,7 +31,7 @@
 
 from .._agents import AgentProtocol
 from .._clients import ChatClientProtocol
-from .._types import ChatMessage, Role
+from .._types import ChatMessage, ChatOptions, Role
 from ._agent_executor import AgentExecutorRequest, AgentExecutorResponse
 from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator
 from ._checkpoint import CheckpointStorage
@@ -320,7 +320,7 @@ def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
     async def _apply_directive(
         self,
         directive: GroupChatDirective,
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, ChatMessage],
+        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
     ) -> None:
         """Execute a manager directive by either finishing the workflow or routing to a participant.
 
@@ -366,7 +366,7 @@ async def _apply_directive(
             self._conversation.extend((final_message,))
             self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message))
             self._pending_agent = None
-            await ctx.yield_output(final_message)
+            await ctx.yield_output(list(self._conversation))
             return
 
         agent_name = directive.agent_name
@@ -415,7 +415,7 @@ async def _ingest_participant_message(
         self,
         participant_name: str,
         message: ChatMessage,
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, ChatMessage],
+        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
     ) -> None:
         """Common response ingestion logic shared by agent and custom participants."""
         if participant_name not in self._participants:
@@ -427,12 +427,13 @@
         self._pending_agent = None
 
         if self._check_round_limit():
-            await ctx.yield_output(
-                self._create_completion_message(
-                    text="Conversation halted after reaching manager round limit.",
-                    reason="max_rounds reached after response",
-                )
+            final_message = self._create_completion_message(
+                text="Conversation halted after reaching manager round limit.",
+                reason="max_rounds reached after response",
             )
+            self._conversation.extend((final_message,))
+            self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message))
+            await ctx.yield_output(list(self._conversation))
             return
 
         directive = await self._manager(self._build_state())
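
Both here and in `_apply_directive` above, the orchestrator now yields the accumulated conversation (`list[ChatMessage]`) rather than only the final completion message. A minimal consumer sketch of the new output shape, assuming the workflow is driven with `run(...)` and read back with `get_outputs()` (accessor names follow the framework's public samples and should be treated as assumptions here):

.. code-block:: python

    from agent_framework import ChatMessage


    async def print_transcript(workflow) -> None:
        # Drive the group chat to completion; get_outputs() gathers every value
        # the orchestrator emitted via ctx.yield_output(...), here a single
        # list[ChatMessage] holding the full conversation.
        events = await workflow.run("Summarize the research findings")
        conversation: list[ChatMessage] = events.get_outputs()[0]
        for message in conversation:
            speaker = message.author_name or message.role.value
            print(f"{speaker}: {message.text}")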
@@ -469,7 +470,7 @@ def _extract_agent_message(response: AgentExecutorResponse, participant_name: st
     async def _handle_task_message(
         self,
         task_message: ChatMessage,
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, ChatMessage],
+        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
     ) -> None:
         """Initialize orchestrator state and start the manager-directed conversation loop.
 
@@ -526,7 +527,7 @@ async def _handle_task_message(
     async def handle_str(
         self,
         task: str,
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, ChatMessage],
+        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
     ) -> None:
         """Handler for string input as workflow entry point.
 
@@ -545,7 +546,7 @@
     async def handle_chat_message(
         self,
         task_message: ChatMessage,
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, ChatMessage],
+        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
     ) -> None:
         """Handler for ChatMessage input as workflow entry point.
 
@@ -564,7 +565,7 @@
     async def handle_conversation(
         self,
         conversation: list[ChatMessage],
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, ChatMessage],
+        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
     ) -> None:
         """Handler for conversation history as workflow entry point.
 
@@ -602,7 +603,7 @@ async def handle_conversation(
     async def handle_agent_response(
         self,
         response: _GroupChatResponseMessage,
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, ChatMessage],
+        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
     ) -> None:
         """Handle responses from custom participant executors."""
         await self._ingest_participant_message(response.agent_name, response.message, ctx)
@@ -611,7 +612,7 @@
     async def handle_agent_executor_response(
         self,
         response: AgentExecutorResponse,
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, ChatMessage],
+        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
     ) -> None:
         """Handle direct AgentExecutor responses."""
         participant_name = self._registry.get_participant_name(response.executor_id)
@@ -748,6 +749,9 @@ class GroupChatBuilder:
 
     .. code-block:: python
 
+        from agent_framework import GroupChatBuilder, GroupChatStateSnapshot
+
+
         def select_next_speaker(state: GroupChatStateSnapshot) -> str | None:
             # state contains: task, participants, conversation, history, round_index
             if state["round_index"] >= 5:
@@ -779,6 +783,28 @@ def select_next_speaker(state: GroupChatStateSnapshot) -> str | None:
             .build()
         )
 
+    *Pattern 3: LLM-based selection with custom ChatOptions*
+
+    .. code-block:: python
+
+        from agent_framework import ChatOptions
+        from agent_framework.azure import AzureOpenAIChatClient
+
+        # Configure LLM parameters for deterministic manager decisions
+        chat_options = ChatOptions(temperature=0.3, seed=42, max_tokens=500)
+
+        workflow = (
+            GroupChatBuilder()
+            .set_prompt_based_manager(
+                chat_client=AzureOpenAIChatClient(),
+                display_name="Coordinator",
+                chat_options=chat_options,
+            )
+            .participants([researcher, writer])
+            .with_max_rounds(10)
+            .build()
+        )
+
     **Participant Specification:**
 
     Two ways to specify participants:
@@ -848,6 +874,7 @@ def set_prompt_based_manager(
         *,
         instructions: str | None = None,
         display_name: str | None = None,
+        chat_options: ChatOptions | None = None,
    ) -> "GroupChatBuilder":
         r"""Configure the default prompt-based manager driven by an LLM chat client.
 
@@ -862,6 +889,9 @@
             with the task description, participant list, and structured output format to guide
             the LLM in selecting the next speaker or completing the conversation.
             display_name: Optional conversational display name for manager messages.
+            chat_options: Optional ChatOptions to configure LLM parameters (temperature, seed, etc.)
+                for the manager's decision-making. These options are applied when the manager calls
+                the chat client to select the next speaker.
 
         Returns:
             Self for fluent chaining.
@@ -873,15 +903,23 @@
 
         .. code-block:: python
 
-            from agent_framework import GroupChatBuilder, DEFAULT_MANAGER_INSTRUCTIONS
+            from agent_framework import GroupChatBuilder, ChatOptions, DEFAULT_MANAGER_INSTRUCTIONS
 
             custom_instructions = (
                 DEFAULT_MANAGER_INSTRUCTIONS + "\\n\\nPrioritize the researcher for data analysis tasks."
             )
 
+            # Configure with custom temperature and seed for reproducibility
+            options = ChatOptions(temperature=0.3, seed=42)
+
             workflow = (
                 GroupChatBuilder()
-                .set_prompt_based_manager(chat_client, instructions=custom_instructions, display_name="Coordinator")
+                .set_prompt_based_manager(
+                    chat_client,
+                    instructions=custom_instructions,
+                    display_name="Coordinator",
+                    chat_options=options,
+                )
                 .participants(researcher=researcher, writer=writer)
                 .build()
             )
@@ -890,6 +928,7 @@ def set_prompt_based_manager(
             chat_client,
             instructions=instructions,
             name=display_name,
+            chat_options=chat_options,
         )
         return self._set_manager_function(manager, display_name)
 
@@ -908,6 +947,15 @@ def select_speakers(
         function receives an immutable snapshot of the current conversation state and returns
         the name of the next participant to speak, or None to finish the conversation.
 
+        The selector function can implement any logic including:
+        - Simple round-robin or rule-based selection
+        - LLM-based decision making with custom prompts
+        - Conversation summarization before routing to the next agent
+        - Custom metadata or context passing
+
+        For advanced scenarios, return a GroupChatDirective instead of a string to include
+        custom instructions or metadata for the next participant.
+
         The selector function signature:
             def select_next_speaker(state: GroupChatStateSnapshot) -> str | None:
                 # state contains: task, participants, conversation, history, round_index
@@ -917,6 +965,7 @@ def select_next_speaker(state: GroupChatStateSnapshot) -> str | None:
         Args:
             selector: Function that takes GroupChatStateSnapshot and returns the next speaker's
                 name (str) to continue the conversation, or None to finish. May be sync or async.
+                Can also return GroupChatDirective for advanced control (instruction, metadata).
             display_name: Optional name shown in conversation history for orchestrator messages
                 (defaults to "manager").
             final_message: Optional final message (or factory) emitted when selector returns None
@@ -925,7 +974,7 @@ def select_next_speaker(state: GroupChatStateSnapshot) -> str | None:
         Returns:
             Self for fluent chaining.
 
-        Example:
+        Example (simple):
 
         .. code-block:: python
 
Expand All @@ -945,6 +994,30 @@ def select_next_speaker(state: GroupChatStateSnapshot) -> str | None:
.build()
)

Example (with LLM and custom instructions):

.. code-block:: python

from agent_framework import GroupChatDirective


async def llm_based_selector(state: GroupChatStateSnapshot) -> GroupChatDirective | None:
if state["round_index"] >= 5:
return GroupChatDirective(finish=True)

# Use LLM to decide next speaker and summarize conversation
conversation_summary = await summarize_with_llm(state["conversation"])
next_agent = await pick_agent_with_llm(state["participants"], state["task"])

# Pass custom instruction to the selected agent
return GroupChatDirective(
agent_name=next_agent,
instruction=f"Context summary: {conversation_summary}",
)


workflow = GroupChatBuilder().select_speakers(llm_based_selector).participants(...).build()

Note:
Cannot be combined with set_prompt_based_manager(). Choose one orchestration strategy.
"""
@@ -1304,10 +1377,12 @@ def __init__(
         *,
         instructions: str | None = None,
         name: str | None = None,
+        chat_options: ChatOptions | None = None,
     ) -> None:
         self._chat_client = chat_client
         self._instructions = instructions or DEFAULT_MANAGER_INSTRUCTIONS
         self._name = name or "GroupChatManager"
+        self._chat_options = chat_options
 
     @property
     def name(self) -> str:
@@ -1332,7 +1407,15 @@ async def __call__(self, state: GroupChatStateSnapshot) -> GroupChatDirective:
 
         messages: list[ChatMessage] = [system_message, *conversation]
 
-        response = await self._chat_client.get_response(messages, response_format=ManagerDirectiveModel)
+        # Apply chat options if provided, otherwise just use response_format
+        if self._chat_options is not None:
+            response = await self._chat_client.get_response(
+                messages,
+                response_format=ManagerDirectiveModel,
+                **self._chat_options.to_dict(),
+            )
+        else:
+            response = await self._chat_client.get_response(messages, response_format=ManagerDirectiveModel)
 
         directive_model: ManagerDirectiveModel
         if response.value is not None:
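
For a manager configured with `ChatOptions(temperature=0.3, seed=42)`, the options branch above amounts to roughly the following call, assuming `ChatOptions.to_dict()` yields only the explicitly set fields as keyword arguments:

.. code-block:: python

    # Rough expansion of the options branch for ChatOptions(temperature=0.3, seed=42);
    # to_dict() is assumed to emit only fields that were explicitly set.
    response = await self._chat_client.get_response(
        messages,
        response_format=ManagerDirectiveModel,
        temperature=0.3,
        seed=42,
    )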