(First file: the Agent Framework adapter, covering __init__, tracing setup, and agent_run)
@@ -8,8 +8,9 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List
import inspect

from agent_framework import AgentProtocol, AIFunction
from agent_framework import AgentProtocol, AIFunction, InMemoryCheckpointStorage
from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module
from agent_framework._workflows import get_checkpoint_summary
from opentelemetry import trace

from azure.ai.agentserver.core.client.tools import OAuthConsentRequiredError
@@ -27,6 +28,7 @@
AgentFrameworkOutputNonStreamingConverter,
)
from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter
from .models.human_in_the_loop_helper import HumanInTheLoopHelper
from .models.constants import Constants
from .tool_client import ToolClient

@@ -85,6 +87,9 @@ def __init__(self, agent: Union[AgentProtocol, AgentFactory],
super().__init__(credentials=credentials, **kwargs) # pylint: disable=unexpected-keyword-arg
self._agent_or_factory: Union[AgentProtocol, AgentFactory] = agent
self._resolved_agent: "Optional[AgentProtocol]" = None
self._hitl_helper = HumanInTheLoopHelper()
self._checkpoint_storage = InMemoryCheckpointStorage()
self._agent_thread_in_memory = {}
# If agent is already instantiated, use it directly
if isinstance(agent, AgentProtocol):
self._resolved_agent = agent
@@ -187,9 +192,13 @@ def init_tracing(self):
self.tracer = trace.get_tracer(__name__)

def setup_tracing_with_azure_ai_client(self, project_endpoint: str):
logger.info("Setting up tracing with AzureAIClient")
logger.info(f"Project endpoint for tracing credential: {self.credentials}")
async def setup_async():
async with AzureAIClient(
project_endpoint=project_endpoint, async_credential=self.credentials
project_endpoint=project_endpoint,
async_credential=self.credentials,
credential=self.credentials,
) as agent_client:
await agent_client.setup_azure_ai_observability()
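Note: the code that drives setup_async is collapsed out of this hunk. Below is a minimal sketch of one way to run it from synchronous code; asyncio.run and the standalone function name setup_tracing are assumptions, not the code under review.

import asyncio

from agent_framework.azure import AzureAIClient

def setup_tracing(project_endpoint: str, credentials) -> None:
    # Mirrors setup_tracing_with_azure_ai_client above: open the client,
    # register Azure AI observability, then tear the client down.
    async def setup_async() -> None:
        async with AzureAIClient(
            project_endpoint=project_endpoint,
            async_credential=credentials,
            credential=credentials,
        ) as agent_client:
            await agent_client.setup_azure_ai_observability()

    asyncio.run(setup_async())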

@@ -223,24 +232,47 @@ async def agent_run( # pylint: disable=too-many-statements

logger.info(f"Starting agent_run with stream={context.stream}")
request_input = context.request.get("input")

input_converter = AgentFrameworkInputConverter()
message = input_converter.transform_input(request_input)
# TODO: load agent thread from storage and deserialize
agent_thread = self._agent_thread_in_memory.get(context.conversation_id, agent.get_new_thread())

last_checkpoint = None
if self._checkpoint_storage:
checkpoints = await self._checkpoint_storage.list_checkpoints()
            last_checkpoint = checkpoints[-1] if checkpoints else None
logger.info(f"Last checkpoint data: {last_checkpoint.to_dict() if last_checkpoint else 'None'}")
if last_checkpoint:
summary = get_checkpoint_summary(last_checkpoint)
logger.info(f"Last checkpoint summary status: {summary.status}")
if summary.status == "completed":
last_checkpoint = None # Do not resume from completed checkpoints

input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper)
message = await input_converter.transform_input(
request_input,
agent_thread=agent_thread,
checkpoint=last_checkpoint)
logger.debug(f"Transformed input message type: {type(message)}")

# Use split converters
if context.stream:
logger.info("Running agent in streaming mode")
streaming_converter = AgentFrameworkOutputStreamingConverter(context)
streaming_converter = AgentFrameworkOutputStreamingConverter(context, hitl_helper=self._hitl_helper)

async def stream_updates():
try:
update_count = 0
updates = agent.run_stream(message)
updates = agent.run_stream(
message,
thread=agent_thread,
checkpoint_storage=self._checkpoint_storage,
checkpoint_id=last_checkpoint.checkpoint_id if last_checkpoint else None,
)
async for event in streaming_converter.convert(updates):
update_count += 1
yield event


if agent_thread:
self._agent_thread_in_memory[context.conversation_id] = agent_thread
logger.info("Streaming completed with %d updates", update_count)
finally:
# Close tool_client if it was created for this request
@@ -255,9 +287,15 @@ async def stream_updates():

# Non-streaming path
logger.info("Running agent in non-streaming mode")
non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context)
result = await agent.run(message)
logger.debug(f"Agent run completed, result type: {type(result)}")
non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper)
        result = await agent.run(
            message,
            thread=agent_thread,
            checkpoint_storage=self._checkpoint_storage,
            checkpoint_id=last_checkpoint.checkpoint_id if last_checkpoint else None,
        )
logger.info(f"Agent run completed, result type: {type(result)}")
if agent_thread:
self._agent_thread_in_memory[context.conversation_id] = agent_thread
transformed_result = non_streaming_converter.transform_output_for_response(result)
logger.info("Agent run and transformation completed successfully")
return transformed_result
@@ -279,3 +317,4 @@ async def oauth_consent_stream(error=e):
logger.debug("Closed tool_client after request processing")
except Exception as ex: # pylint: disable=broad-exception-caught
logger.warning(f"Error closing tool_client: {ex}")

(Next file: AgentFrameworkInputConverter)
@@ -5,9 +5,15 @@
# mypy: disable-error-code="no-redef"
from __future__ import annotations

from typing import Dict, List

from agent_framework import ChatMessage, Role as ChatRole
from typing import Dict, List, Optional

from agent_framework import (
AgentThread,
ChatMessage,
RequestInfoEvent,
Role as ChatRole,
WorkflowCheckpoint,
)
from agent_framework._types import TextContent

from azure.ai.agentserver.core.logger import get_logger
@@ -21,10 +27,14 @@ class AgentFrameworkInputConverter:
Accepts: str | List | None
Returns: None | str | ChatMessage | list[str] | list[ChatMessage]
"""
def __init__(self, *, hitl_helper=None) -> None:
self._hitl_helper = hitl_helper

def transform_input(
async def transform_input(
self,
input: str | List[Dict] | None,
agent_thread: Optional[AgentThread] = None,
checkpoint: Optional[WorkflowCheckpoint] = None,
) -> str | ChatMessage | list[str] | list[ChatMessage] | None:
logger.debug("Transforming input of type: %s", type(input))

@@ -33,7 +43,28 @@ def transform_input(

if isinstance(input, str):
return input


if self._hitl_helper:
# load pending requests from checkpoint and thread messages if available
thread_messages = []
if agent_thread:
thread_messages = await agent_thread.message_store.list_messages()
logger.info(f"Thread messages count: {len(thread_messages)}")
pending_hitl_requests = self._hitl_helper.get_pending_hitl_request(thread_messages, checkpoint)
logger.info(f"Pending HitL requests: {list(pending_hitl_requests.keys())}")
hitl_response = self._hitl_helper.validate_and_convert_hitl_response(
input,
pending_requests=pending_hitl_requests)
logger.info(f"HitL response validation result: {[m.to_dict() for m in hitl_response]}")
if hitl_response:
return hitl_response

return self._transform_input_internal(input)

def _transform_input_internal(
self,
input: str | List[Dict] | None,
) -> str | ChatMessage | list[str] | list[ChatMessage] | None:
try:
if isinstance(input, list):
messages: list[str | ChatMessage] = []
@@ -118,3 +149,35 @@ def _extract_input_text(self, content_item: Dict) -> str:
if isinstance(text_content, str):
return text_content
return None # type: ignore

def _validate_and_convert_hitl_response(
self,
pending_request: Dict,
input: List[Dict],
) -> Optional[List[ChatMessage]]:
if not self._hitl_helper:
logger.warning("HitL helper not provided; cannot validate HitL response.")
return None
if isinstance(input, str):
logger.warning("Expected list input for HitL response validation, got str.")
return None
if not isinstance(input, list) or len(input) != 1:
logger.warning("Expected single-item list input for HitL response validation.")
return None

item = input[0]
if item.get("type") != "function_call_output":
logger.warning("Expected function_call_output type for HitL response validation.")
return None
call_id = item.get("call_id", None)
if not call_id or call_id not in pending_request:
logger.warning("Function call output missing valid call_id for HitL response validation.")
return None
request_info = pending_request[call_id]
if isinstance(request_info, dict):
request_info = RequestInfoEvent.from_dict(request_info)
if not isinstance(request_info, RequestInfoEvent):
logger.warning("No valid pending request info found for call_id: %s", call_id)
return None

return self._hitl_helper.convert_response(request_info, item)
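Note: a usage sketch of the new async HitL input path. The wiring matches this diff, but the function_call_output payload shape beyond type and call_id, and the ids themselves, are assumptions.

from agent_framework import AgentThread, WorkflowCheckpoint

async def answer_pending_request(
    converter: AgentFrameworkInputConverter,
    thread: AgentThread,
    checkpoint: WorkflowCheckpoint | None,
):
    # One function_call_output item answering a pending HitL request;
    # call_id must match a request recovered from the thread or checkpoint.
    response_items = [
        {
            "type": "function_call_output",
            "call_id": "call_abc123",        # invented example id
            "output": '{"approved": true}',  # assumed payload shape
        }
    ]
    # Returns resumed chat messages on a match; any other input falls
    # through to _transform_input_internal.
    return await converter.transform_input(
        response_items, agent_thread=thread, checkpoint=checkpoint
    )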
(Next file: AgentFrameworkOutputNonStreamingConverter)
@@ -7,7 +7,14 @@
import json
from typing import Any, List

from agent_framework import AgentRunResponse, FunctionCallContent, FunctionResultContent, ErrorContent, TextContent
from agent_framework import (
AgentRunResponse,
FunctionCallContent,
FunctionResultContent,
ErrorContent,
TextContent,
)
from agent_framework._types import UserInputRequestContents

from azure.ai.agentserver.core import AgentRunContext
from azure.ai.agentserver.core.logger import get_logger
@@ -21,17 +28,19 @@

from .agent_id_generator import AgentIdGenerator
from .constants import Constants
from .human_in_the_loop_helper import HumanInTheLoopHelper

logger = get_logger()


class AgentFrameworkOutputNonStreamingConverter: # pylint: disable=name-too-long
"""Non-streaming converter: AgentRunResponse -> OpenAIResponse."""

def __init__(self, context: AgentRunContext):
def __init__(self, context: AgentRunContext, *, hitl_helper: HumanInTheLoopHelper):
self._context = context
self._response_id = None
self._response_created_at = None
self._hitl_helper = hitl_helper

def _ensure_response_started(self) -> None:
if not self._response_id:
@@ -120,6 +129,8 @@ def _append_content_item(self, content: Any, sink: List[dict], author_name: str)
self._append_function_call_content(content, sink, author_name)
elif isinstance(content, FunctionResultContent):
self._append_function_result_content(content, sink, author_name)
elif isinstance(content, UserInputRequestContents):
self._append_user_input_request_contents(content, sink, author_name)
elif isinstance(content, ErrorContent):
raise ValueError(f"ErrorContent received: code={content.error_code}, message={content.message}")
else:
@@ -205,6 +216,22 @@ def _append_function_result_content(self, content: FunctionResultContent, sink:
call_id,
len(result),
)

    def _append_user_input_request_contents(
        self, content: UserInputRequestContents, sink: List[dict], author_name: str
    ) -> None:
        item_id = self._context.id_generator.generate_function_call_id()
        converted = self._hitl_helper.convert_user_input_request_content(content)
        sink.append(
            {
                "id": item_id,
                "type": "function_call",
                "status": "in_progress",
                "call_id": converted["call_id"],
                "name": converted["name"],
                "arguments": converted["arguments"],
                "created_by": self._build_created_by(author_name),
            }
        )
        logger.debug(" added user_input_request item id=%s call_id=%s", item_id, converted["call_id"])

# ------------- simple normalization helper -------------------------
def _coerce_result_text(self, value: Any) -> str | dict: