diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b161de17..4c9114b3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,7 +1,7 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.12.8
+    rev: v0.14.0
     hooks:
       # Run the linter.
       - id: ruff-check
diff --git a/agents-core/vision_agents/_generate_sfu_events.py b/agents-core/vision_agents/_generate_sfu_events.py
index 7c1d29a2..d039a245 100644
--- a/agents-core/vision_agents/_generate_sfu_events.py
+++ b/agents-core/vision_agents/_generate_sfu_events.py
@@ -8,7 +8,16 @@ from __future__ import annotations
 
 import pathlib
-from typing import Dict as TypingDict, Iterable, List, Optional, Sequence, Set, Tuple, Type
+from typing import (
+    Dict as TypingDict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+)
 
 from google.protobuf.descriptor import FieldDescriptor
 from google.protobuf.message import Message
 
@@ -59,16 +68,16 @@ def _collect_message_types() -> TypingDict[str, Type[Message]]:
     """Collect all message types referenced in event fields (recursively)."""
     message_types: TypingDict[str, Type[Message]] = {}
     to_process: Set[str] = set()
-
+
     # Collect all message types from events
     for proto_name, message_cls in _iter_protobuf_messages():
         for field_desc in message_cls.DESCRIPTOR.fields:
             if field_desc.type == FieldDescriptor.TYPE_MESSAGE:
                 message_type_name = field_desc.message_type.full_name
-                if message_type_name.startswith('stream.video.sfu.model'):
-                    class_name = message_type_name.split('.')[-1]
+                if message_type_name.startswith("stream.video.sfu.model"):
+                    class_name = message_type_name.split(".")[-1]
                     to_process.add(class_name)
-
+
     # Process messages and their nested message types recursively
     processed: Set[str] = set()
     while to_process:
@@ -76,20 +85,20 @@ def _collect_message_types() -> TypingDict[str, Type[Message]]:
         if class_name in processed:
             continue
         processed.add(class_name)
-
+
         if hasattr(models_pb2, class_name):
             message_cls = getattr(models_pb2, class_name)
             message_types[class_name] = message_cls
-
+
             # Check for nested message types
             for field_desc in message_cls.DESCRIPTOR.fields:
                 if field_desc.type == FieldDescriptor.TYPE_MESSAGE:
                     nested_type_name = field_desc.message_type.full_name
-                    if nested_type_name.startswith('stream.video.sfu.model'):
-                        nested_class_name = nested_type_name.split('.')[-1]
+                    if nested_type_name.startswith("stream.video.sfu.model"):
+                        nested_class_name = nested_type_name.split(".")[-1]
                         if nested_class_name not in processed:
                             to_process.add(nested_class_name)
-
+
     return message_types
 
 
@@ -97,13 +106,13 @@ def _get_message_type_name(field_descriptor: FieldDescriptor) -> Optional[str]:
     """Get the wrapper class name for a message type field."""
     if field_descriptor.type == FieldDescriptor.TYPE_MESSAGE:
         message_type_name = field_descriptor.message_type.full_name
-
+
         # Map known message types to their wrapper classes
-        if message_type_name.startswith('stream.video.sfu.model'):
-            class_name = message_type_name.split('.')[-1]
+        if message_type_name.startswith("stream.video.sfu.model"):
+            class_name = message_type_name.split(".")[-1]
             # Return the class name which will be defined at the top
             return class_name
-
+
     return None
 
 
@@ -112,14 +121,14 @@ def _get_enum_type_name(field_descriptor: FieldDescriptor) -> Optional[str]:
     if field_descriptor.type == FieldDescriptor.TYPE_ENUM:
         enum_type_name = field_descriptor.enum_type.full_name
         # Extract just the class name from the full name (e.g., stream.video.sfu.models.TrackType -> TrackType)
-        if enum_type_name.startswith('stream.video.sfu.models.'):
-            return enum_type_name.split('.')[-1]
+        if enum_type_name.startswith("stream.video.sfu.models."):
+            return enum_type_name.split(".")[-1]
     return None
 
 
 def _get_python_type_from_protobuf_field(field_descriptor: FieldDescriptor) -> str:
     """Determine Python type from protobuf field descriptor.
-
+
     Maps protobuf field types to their corresponding Python types.
     All fields are returned as Optional since we want optional semantics.
     """
@@ -141,7 +150,7 @@ def _get_python_type_from_protobuf_field(field_descriptor: FieldDescriptor) -> s
         FieldDescriptor.TYPE_SINT32: "int",
         FieldDescriptor.TYPE_SINT64: "int",
     }
-
+
     # Handle repeated fields (lists)
     if field_descriptor.is_repeated:
         # For enum types in repeated fields - use int, documented in docstring
@@ -156,16 +165,16 @@ def _get_python_type_from_protobuf_field(field_descriptor: FieldDescriptor) -> s
         else:
             base_type = type_map.get(field_descriptor.type, "Any")
             return f"Optional[List[{base_type}]]"
-
+
     # Handle message types (nested protobuf messages)
     if field_descriptor.type == FieldDescriptor.TYPE_MESSAGE:
         message_type = _get_message_type_name(field_descriptor)
         return f"Optional[{message_type}]" if message_type else "Optional[Any]"
-
+
     # Handle enum types - use int, documented in docstring
     if field_descriptor.type == FieldDescriptor.TYPE_ENUM:
         return "Optional[int]"
-
+
     # Handle scalar types - all made optional
     base_type = type_map.get(field_descriptor.type, "Any")
     return f"Optional[{base_type}]"
@@ -174,7 +183,7 @@ def _get_python_type_from_protobuf_field(field_descriptor: FieldDescriptor) -> s
 def _render_message_wrapper(class_name: str, message_cls: Type[Message]) -> List[str]:
     """Generate a dataclass wrapper for a protobuf message type (like Participant)."""
     lines = ["@dataclass", f"class {class_name}(DataClassJsonMixin):"]
-
+
     # Build docstring with enum field documentation
     docstring_lines = [f"Wrapper for {message_cls.DESCRIPTOR.full_name}."]
     field_descriptors = message_cls.DESCRIPTOR.fields
@@ -184,70 +193,80 @@ def _render_message_wrapper(class_name: str, message_cls: Type[Message]) -> List
         enum_type_name = _get_enum_type_name(field_desc)
         if enum_type_name:
             enum_fields.append((field_desc.name, enum_type_name))
-
+
     if enum_fields:
         docstring_lines.append("")
         docstring_lines.append("Enum fields (use values from models_pb2):")
         for field_name, enum_name in enum_fields:
            docstring_lines.append(f"  - {field_name}: {enum_name}")
-
+
     # Add docstring
     if len(docstring_lines) == 1:
-        lines.append(f"    \"\"\"{docstring_lines[0]}\"\"\"")
+        lines.append(f'    """{docstring_lines[0]}"""')
     else:
-        lines.append("    \"\"\"" + docstring_lines[0])
+        lines.append('    """' + docstring_lines[0])
         for line in docstring_lines[1:]:
             lines.append("    " + line)
-        lines.append("    \"\"\"")
-
+        lines.append('    """')
+
     if not field_descriptors:
         lines.append("    pass")
         lines.append("")
         return lines
-
+
    # Generate fields with proper types
    for field_desc in field_descriptors:
        field_name = field_desc.name
        python_type = _get_python_type_from_protobuf_field(field_desc)
-
+
        # Use proper default for optional fields
        lines.append(f"    {field_name}: {python_type} = None")
-
+
    lines.append("")
    lines.append("    @classmethod")
    lines.append(f"    def from_proto(cls, proto_obj) -> '{class_name}':")
-    lines.append(f"        \"\"\"Create from protobuf {class_name}.\"\"\"")
+    lines.append(f'        """Create from protobuf {class_name}."""')
    lines.append("        if proto_obj is None:")
    lines.append("            return cls()")
    lines.append("        return cls(")
-
+
    # Generate field assignments
    for i, field_desc in enumerate(field_descriptors):
        field_name = field_desc.name
        comma = "," if i < len(field_descriptors) - 1 else ""
-
+
        # Handle different field types
        if field_desc.type == FieldDescriptor.TYPE_MESSAGE and field_desc.is_repeated:
            # Repeated message fields
            message_type = _get_message_type_name(field_desc)
            if message_type:
-                lines.append(f"            {field_name}=[{message_type}.from_proto(item) for item in proto_obj.{field_name}]{comma}")
+                lines.append(
+                    f"            {field_name}=[{message_type}.from_proto(item) for item in proto_obj.{field_name}]{comma}"
+                )
            else:
-                lines.append(f"            {field_name}=list(proto_obj.{field_name}){comma}")
+                lines.append(
+                    f"            {field_name}=list(proto_obj.{field_name}){comma}"
+                )
        elif field_desc.type == FieldDescriptor.TYPE_MESSAGE:
            # Single message field
            message_type = _get_message_type_name(field_desc)
            if message_type:
-                lines.append(f"            {field_name}={message_type}.from_proto(proto_obj.{field_name}) if proto_obj.HasField('{field_name}') else None{comma}")
+                lines.append(
+                    f"            {field_name}={message_type}.from_proto(proto_obj.{field_name}) if proto_obj.HasField('{field_name}') else None{comma}"
+                )
            else:
-                lines.append(f"            {field_name}=proto_obj.{field_name} if proto_obj.HasField('{field_name}') else None{comma}")
+                lines.append(
+                    f"            {field_name}=proto_obj.{field_name} if proto_obj.HasField('{field_name}') else None{comma}"
+                )
        elif field_desc.is_repeated:
            # Repeated scalar/enum fields - convert to list of ints for enums
-            lines.append(f"            {field_name}=list(proto_obj.{field_name}){comma}")
+            lines.append(
+                f"            {field_name}=list(proto_obj.{field_name}){comma}"
+            )
        else:
            # Regular scalar/enum fields
            lines.append(f"            {field_name}=proto_obj.{field_name}{comma}")
-
+
    lines.append("        )")
    lines.append("")
    return lines
@@ -256,94 +275,116 @@ def _render_message_wrapper(class_name: str, message_cls: Type[Message]) -> List
 def _render_class(proto_name: str, message_cls: Type[Message]) -> List[str]:
     class_name = _class_name(proto_name)
     event_type = message_cls.DESCRIPTOR.full_name
-
+
     # Get field descriptors for this message
     field_descriptors = message_cls.DESCRIPTOR.fields
 
     lines = ["@dataclass", f"class {class_name}(BaseEvent):"]
-    lines.append(f"    \"\"\"Dataclass event for {message_cls.__module__}.{message_cls.__name__}.\"\"\"")
-
+    lines.append(
+        f'    """Dataclass event for {message_cls.__module__}.{message_cls.__name__}."""'
+    )
+
     # Override type field with the specific event type
-    lines.append(f"    type: str = field(default=\"{event_type}\", init=False)")
-
+    lines.append(f'    type: str = field(default="{event_type}", init=False)')
+
     # Add payload field (optional to match BaseEvent pattern)
-    lines.append(f"    payload: Optional[events_pb2.{proto_name}] = field(default=None, repr=False)")
-
+    lines.append(
+        f"    payload: Optional[events_pb2.{proto_name}] = field(default=None, repr=False)"
+    )
+
     # Add property fields for each protobuf field (skip fields that conflict with BaseEvent)
     base_event_fields = {"type", "event_id", "timestamp", "session_id", "user_metadata"}
     for field_desc in field_descriptors:
         field_name = field_desc.name
-        if field_name in base_event_fields:  # Skip fields that conflict with BaseEvent fields
+        if (
+            field_name in base_event_fields
+        ):  # Skip fields that conflict with BaseEvent fields
             continue
 
         type_hint = _get_python_type_from_protobuf_field(field_desc)
 
         lines.append("")
         lines.append("    @property")
         lines.append(f"    def {field_name}(self) -> {type_hint}:")
-
+
         # Build docstring with enum information if applicable
         docstring = f"Access {field_name} field from the protobuf payload."
         if field_desc.type == FieldDescriptor.TYPE_ENUM:
             enum_type_name = _get_enum_type_name(field_desc)
             if enum_type_name:
                 docstring += f" Use models_pb2.{enum_type_name} enum."
-
-        lines.append(f"        \"\"\"{docstring}\"\"\"")
+
+        lines.append(f'        """{docstring}"""')
         lines.append("        if self.payload is None:")
         lines.append("            return None")
-
+
         # Handle message type fields - wrap them in our dataclass
         if field_desc.type == FieldDescriptor.TYPE_MESSAGE:
             message_type = _get_message_type_name(field_desc)
             if message_type:
                 if field_desc.is_repeated:
-                    lines.append(f"        proto_list = getattr(self.payload, '{field_name}', [])")
-                    lines.append(f"        return [{message_type}.from_proto(item) for item in proto_list] if proto_list else None")
+                    lines.append(
+                        f"        proto_list = getattr(self.payload, '{field_name}', [])"
+                    )
+                    lines.append(
+                        f"        return [{message_type}.from_proto(item) for item in proto_list] if proto_list else None"
+                    )
                 else:
-                    lines.append(f"        proto_val = getattr(self.payload, '{field_name}', None)")
-                    lines.append(f"        return {message_type}.from_proto(proto_val) if proto_val is not None else None")
+                    lines.append(
+                        f"        proto_val = getattr(self.payload, '{field_name}', None)"
+                    )
+                    lines.append(
+                        f"        return {message_type}.from_proto(proto_val) if proto_val is not None else None"
+                    )
             else:
-                lines.append(f"        return getattr(self.payload, '{field_name}', None)")
+                lines.append(
+                    f"        return getattr(self.payload, '{field_name}', None)"
+                )
         else:
             # Scalar or enum fields
             lines.append(f"        return getattr(self.payload, '{field_name}', None)")
-
+
     lines.append("")
     lines.append("    @classmethod")
-    lines.append("    def from_proto(cls, proto_obj: events_pb2.{0}, **extra):".format(proto_name))
-    lines.append("        \"\"\"Create event instance from protobuf message.\"\"\"")
+    lines.append(
+        "    def from_proto(cls, proto_obj: events_pb2.{0}, **extra):".format(
+            proto_name
+        )
+    )
+    lines.append('        """Create event instance from protobuf message."""')
     lines.append("        return cls(payload=proto_obj, **extra)")
     lines.append("")
     lines.append("    def as_dict(self) -> Dict[str, Any]:")
-    lines.append("        \"\"\"Convert protobuf payload to dictionary.\"\"\"")
+    lines.append('        """Convert protobuf payload to dictionary."""')
     lines.append("        if self.payload is None:")
     lines.append("            return {}")
     lines.append("        return _to_dict(self.payload)")
     lines.append("")
     lines.append("    def __getattr__(self, item: str):")
-    lines.append("        \"\"\"Delegate attribute access to protobuf payload.\"\"\"")
+    lines.append('        """Delegate attribute access to protobuf payload."""')
     lines.append("        if self.payload is not None:")
     lines.append("            return getattr(self.payload, item)")
-    lines.append("        raise AttributeError(f\"'{self.__class__.__name__}' object has no attribute '{item}'\")")
+    lines.append(
+        "        raise AttributeError(f\"'{self.__class__.__name__}' object has no attribute '{item}'\")"
+    )
     lines.append("")
     return lines
 
 
 def _render_module_body() -> Tuple[List[str], List[str], List[str]]:
     """Generate message wrappers and event classes.
-
+
     Returns:
         Tuple of (message_wrapper_blocks, event_class_blocks, event_class_names)
     """
     # Collect all message types used in events
     message_types = _collect_message_types()
-
+
     # Generate message wrapper classes
     message_wrapper_blocks: List[str] = []
     for class_name in sorted(message_types.keys()):
         message_cls = message_types[class_name]
         wrapper_lines = _render_message_wrapper(class_name, message_cls)
         message_wrapper_blocks.append("\n".join(wrapper_lines))
-
+
     # Generate event classes
     event_class_blocks: List[str] = []
     event_class_names: List[str] = []
@@ -361,37 +402,41 @@ def _build_module() -> str:
     message_wrappers, event_classes, event_names = _render_module_body()
 
     parts: List[str] = [
-        "\"\"\"Auto-generated SFU event dataclasses. Do not edit manually.\"\"\"",
+        '"""Auto-generated SFU event dataclasses. Do not edit manually."""',
         "# Generated by _generate_sfu_events.py",
         *HEADER_LINES,
     ]
-
+
     # Add section header for message wrappers
     if message_wrappers:
-        parts.extend([
-            "# " + "=" * 78,
-            "# Message Type Wrappers",
-            "# These are wrappers for protobuf message types used in events",
-            "# " + "=" * 78,
-            "",
-        ])
+        parts.extend(
+            [
+                "# " + "=" * 78,
+                "# Message Type Wrappers",
+                "# These are wrappers for protobuf message types used in events",
+                "# " + "=" * 78,
+                "",
+            ]
+        )
         parts.extend(message_wrappers)
         parts.append("")
-
+
     # Add section header for event classes
-    parts.extend([
-        "# " + "=" * 78,
-        "# Event Classes",
-        "# " + "=" * 78,
-        "",
-    ])
+    parts.extend(
+        [
+            "# " + "=" * 78,
+            "# Event Classes",
+            "# " + "=" * 78,
+            "",
+        ]
+    )
     parts.extend(event_classes)
 
     # Add exports
     exports_section = [
         "",
         "__all__ = (",
-        *[f"    \"{name}\"," for name in event_names],
+        *[f'    "{name}",' for name in event_names],
         ")",
     ]
 
@@ -402,113 +447,135 @@ def _build_module() -> str:
 
 def verify_generated_classes() -> bool:
     """Verify that generated classes match protobuf definitions.
-
+
     Returns:
         True if all checks pass, False otherwise.
     """
     import importlib.util
     import sys
-
+
     # Import the generated module
     target_path = pathlib.Path(__file__).parent / "core" / "edge" / "sfu_events.py"
     if not target_path.exists():
         print("Error: sfu_events.py not found. Run generation first.")
         return False
-
+
     # Dynamically load the module
     spec = importlib.util.spec_from_file_location("sfu_events", target_path)
     if spec is None or spec.loader is None:
         print("Error: Could not load sfu_events module")
         return False
-
+
     sfu_events = importlib.util.module_from_spec(spec)
     sys.modules["sfu_events"] = sfu_events
     spec.loader.exec_module(sfu_events)
-
+
     all_valid = True
-
+
     for proto_name, message_cls in _iter_protobuf_messages():
         class_name = _class_name(proto_name)
-
+
         # Check if class exists in generated module
         if not hasattr(sfu_events, class_name):
             print(f"✗ Class {class_name} not found in generated module")
             all_valid = False
             continue
-
+
         event_class = getattr(sfu_events, class_name)
-
+
         # Verify it's a BaseEvent subclass
-        if not hasattr(event_class, '__mro__'):
+        if not hasattr(event_class, "__mro__"):
             print(f"✗ {class_name} is not a class")
             all_valid = False
             continue
-
+
         # Check field correspondence
         proto_fields = {f.name: f for f in message_cls.DESCRIPTOR.fields}
-
+
         # Check that all protobuf fields are accessible via properties
         for field_name, field_desc in proto_fields.items():
-            if field_name in {"type", "event_id", "timestamp", "session_id", "user_metadata"}:
+            if field_name in {
+                "type",
+                "event_id",
+                "timestamp",
+                "session_id",
+                "user_metadata",
+            }:
                 continue  # Skip BaseEvent fields
-
+
             if not hasattr(event_class, field_name):
-                print(f"✗ {class_name} missing property for protobuf field: {field_name}")
+                print(
+                    f"✗ {class_name} missing property for protobuf field: {field_name}"
+                )
                 all_valid = False
                 continue
-
+
             # Verify it's a property (check on the class itself, not an instance)
             attr = getattr(event_class, field_name, None)
             if not isinstance(attr, property):
-                print(f"✗ {class_name}.{field_name} is not a property (type: {type(attr).__name__})")
+                print(
+                    f"✗ {class_name}.{field_name} is not a property (type: {type(attr).__name__})"
+                )
                 all_valid = False
                 continue
-
+
         print(f"✓ {class_name} verified ({len(proto_fields)} protobuf fields)")
-
+
     return all_valid
 
 
 def verify_field_types() -> None:
     """Verify and display field type mappings for all protobuf messages."""
-    print("\n" + "="*80)
+    print("\n" + "=" * 80)
     print("Field Type Verification Report")
-    print("="*80 + "\n")
-
+    print("=" * 80 + "\n")
+
     for proto_name, message_cls in _iter_protobuf_messages():
         class_name = _class_name(proto_name)
         print(f"\n{class_name} ({proto_name}):")
         print(f"  Protobuf type: {message_cls.DESCRIPTOR.full_name}")
-
+
         field_descriptors = message_cls.DESCRIPTOR.fields
         if not field_descriptors:
             print("  (no fields)")
             continue
-
+
         for field_desc in field_descriptors:
             field_name = field_desc.name
-            if field_name in {"type", "event_id", "timestamp", "session_id", "user_metadata"}:
+            if field_name in {
+                "type",
+                "event_id",
+                "timestamp",
+                "session_id",
+                "user_metadata",
+            }:
                 continue
-
+
             python_type = _get_python_type_from_protobuf_field(field_desc)
             proto_type_name = field_desc.type
-            label = "repeated" if field_desc.is_repeated else "optional" if hasattr(field_desc, "is_optional") else "required"
-
+            label = (
+                "repeated"
+                if field_desc.is_repeated
+                else "optional"
+                if hasattr(field_desc, "is_optional")
+                else "required"
+            )
+
             print(f"  - {field_name}: type={proto_type_name} ({label}) → {python_type}")
 
 
 def main() -> None:
     import sys
-
+
     # Generate sfu_events.py in the core/edge directory
     target_path = pathlib.Path(__file__).parent / "core" / "edge" / "sfu_events.py"
     target_path.write_text(_build_module(), encoding="utf-8")
     print(f"Regenerated {target_path}")
-
+
     # Verify field types
     if "--verify-types" in sys.argv:
         verify_field_types()
-
+
     # Verify generated classes
     if "--verify" in sys.argv:
         print("\nVerifying generated classes...")
@@ -521,4 +588,3 @@ def main() -> None:
 
 if __name__ == "__main__":
     main()
-
diff --git a/agents-core/vision_agents/core/__init__.py b/agents-core/vision_agents/core/__init__.py
index bd69fced..cff20e87 100644
--- a/agents-core/vision_agents/core/__init__.py
+++ b/agents-core/vision_agents/core/__init__.py
@@ -2,4 +2,4 @@
 
 from vision_agents.core.agents import Agent
 
-__all__ = ["Agent", "User"]
\ No newline at end of file
+__all__ = ["Agent", "User"]
diff --git a/agents-core/vision_agents/core/agents/agent_session.py b/agents-core/vision_agents/core/agents/agent_session.py
index 9e1fbfc9..a50a635e 100644
--- a/agents-core/vision_agents/core/agents/agent_session.py
+++ b/agents-core/vision_agents/core/agents/agent_session.py
@@ -26,6 +26,7 @@ class AgentSessionContextManager:
         connection_cm: Optional provider-specific connection context manager
             returned by the edge transport (kept open during the context).
     """
+
     def __init__(self, agent: Agent, connection_cm=None):
         self.agent = agent
         self._connection_cm = connection_cm
diff --git a/agents-core/vision_agents/core/agents/agents.py b/agents-core/vision_agents/core/agents/agents.py
index bdfba57e..38ae818e 100644
--- a/agents-core/vision_agents/core/agents/agents.py
+++ b/agents-core/vision_agents/core/agents/agents.py
@@ -50,6 +50,7 @@ def _log_task_exception(task: asyncio.Task):
     except Exception:
         logger.exception("Error in background task")
 
+
 class Agent:
     """
     Agent class makes it easy to build your own video AI.
@@ -124,7 +125,11 @@ def __init__(
         self._call_context_token: CallContextToken | None = None
 
         # Initialize MCP manager if servers are provided
-        self.mcp_manager = MCPManager(self.mcp_servers, self.llm, self.logger) if self.mcp_servers else None
+        self.mcp_manager = (
+            MCPManager(self.mcp_servers, self.llm, self.logger)
+            if self.mcp_servers
+            else None
+        )
 
         # we sync the user talking and the agent responses to the conversation
         # because we want to support streaming responses and can have delta updates for both
@@ -132,7 +137,7 @@ def __init__(
         self.conversation: Optional[Conversation] = None
         self._user_conversation_handle: Optional[StreamHandle] = None
         self._agent_conversation_handle: Optional[StreamHandle] = None
-
+
         # Track pending transcripts for turn-based response triggering
         self._pending_user_transcripts: Dict[str, str] = {}
 
@@ -153,7 +158,7 @@ def __init__(
         self._current_frame = None
         self._interval_task = None
         self._callback_executed = False
-        self._track_tasks : Dict[str, asyncio.Task] = {}
+        self._track_tasks: Dict[str, asyncio.Task] = {}
         self._connection: Optional[Connection] = None
         self._audio_track: Optional[aiortc.AudioStreamTrack] = None
         self._video_track: Optional[VideoStreamTrack] = None
@@ -194,7 +199,6 @@ def subscribe(self, function):
         """
         return self.events.subscribe(function)
 
-
     async def join(self, call: Call) -> "AgentSessionContextManager":
         # TODO: validation. join can only be called once
         with self.tracer.start_as_current_span("join"):
@@ -311,9 +315,9 @@ async def close(self):
 
         for processor in self.processors:
             processor.close()
-
+
         # Stop all video forwarders
-        if hasattr(self, '_video_forwarders'):
+        if hasattr(self, "_video_forwarders"):
             for forwarder in self._video_forwarders:
                 try:
                     await forwarder.stop()
@@ -499,23 +503,30 @@ async def _on_agent_say(self, event: events.AgentSayEvent):
             )
             self.logger.error(f"Error in agent say: {e}")
 
-    async def say(self, text: str, user_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None):
+    async def say(
+        self,
+        text: str,
+        user_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ):
         """
         Make the agent say something using TTS.
-
+
         This is a convenience method that sends an AgentSayEvent to trigger
         TTS synthesis.
-
+
         Args:
             text: The text for the agent to say
             user_id: Optional user ID for the speech
             metadata: Optional metadata to include with the speech
         """
-        self.events.send(events.AgentSayEvent(
-            plugin_name="agent",
-            text=text,
-            user_id=user_id or self.agent_user.id,
-            metadata=metadata
-        ))
+        self.events.send(
+            events.AgentSayEvent(
+                plugin_name="agent",
+                text=text,
+                user_id=user_id or self.agent_user.id,
+                metadata=metadata,
+            )
+        )
 
     def _setup_turn_detection(self):
         if self.turn_detection:
@@ -571,12 +582,11 @@ async def _reply_to_audio(
                     continue
                 await processor.process_audio(audio_bytes, participant.user_id)
 
-
         # when in Realtime mode call the Realtime directly (non-blocking)
         if self.realtime_mode and isinstance(self.llm, Realtime):
             # TODO: this behaviour should be easy to change in the agent class
             asyncio.create_task(self.llm.simple_audio_response(pcm_data))
-            #task.add_done_callback(lambda t: print(f"Task (send_audio_pcm) error: {t.exception()}"))
+            # task.add_done_callback(lambda t: print(f"Task (send_audio_pcm) error: {t.exception()}"))
         # Process audio through STT
         elif self.stt:
             self.logger.debug(f"🎵 Processing audio from {participant}")
@@ -591,14 +601,12 @@ async def _process_track(self, track_id: str, track_type: int, participant):
         # subscribe to the video track
         track = self.edge.add_track_subscriber(track_id)
         if not track:
-            self.logger.error(
-                f"Failed to subscribe to {track_id}"
-            )
+            self.logger.error(f"Failed to subscribe to {track_id}")
             return
 
         # Import VideoForwarder
         from ..utils.video_forwarder import VideoForwarder
-
+
         # Create a SHARED VideoForwarder for the RAW incoming track
         # This prevents multiple recv() calls competing on the same track
         raw_forwarder = VideoForwarder(
@@ -609,9 +617,9 @@ async def _process_track(self, track_id: str, track_type: int, participant):
         )
         await raw_forwarder.start()
         self.logger.info("🎥 Created raw VideoForwarder for track %s", track_id)
-
+
         # Track forwarders for cleanup
-        if not hasattr(self, '_video_forwarders'):
+        if not hasattr(self, "_video_forwarders"):
             self._video_forwarders = []
         self._video_forwarders.append(raw_forwarder)
 
@@ -620,7 +628,9 @@ async def _process_track(self, track_id: str, track_type: int, participant):
             if self._video_track:
                 # We have a video publisher (e.g., YOLO processor)
                 # Create a separate forwarder for the PROCESSED video track
-                self.logger.info("🎥 Forwarding PROCESSED video frames to Realtime provider")
+                self.logger.info(
+                    "🎥 Forwarding PROCESSED video frames to Realtime provider"
+                )
                 processed_forwarder = VideoForwarder(
                     self._video_track,  # type: ignore[arg-type]
                     max_buffer=30,
@@ -629,23 +639,28 @@ async def _process_track(self, track_id: str, track_type: int, participant):
                 )
                 await processed_forwarder.start()
                 self._video_forwarders.append(processed_forwarder)
-
+
                 if isinstance(self.llm, Realtime):
                     # Send PROCESSED frames with the processed forwarder
-                    await self.llm._watch_video_track(self._video_track, shared_forwarder=processed_forwarder)
+                    await self.llm._watch_video_track(
+                        self._video_track, shared_forwarder=processed_forwarder
+                    )
             else:
                 # No video publisher, send raw frames
                 self.logger.info("🎥 Forwarding RAW video frames to Realtime provider")
                 if isinstance(self.llm, Realtime):
-                    await self.llm._watch_video_track(track, shared_forwarder=raw_forwarder)
-
+                    await self.llm._watch_video_track(
+                        track, shared_forwarder=raw_forwarder
+                    )
 
         hasImageProcessers = len(self.image_processors) > 0
 
         # video processors - pass the raw forwarder (they process incoming frames)
         for processor in self.video_processors:
             try:
-                await processor.process_video(track, participant.user_id, shared_forwarder=raw_forwarder)
+                await processor.process_video(
+                    track, participant.user_id, shared_forwarder=raw_forwarder
+                )
             except Exception as e:
                 self.logger.error(
                     f"Error in video processor {type(processor).__name__}: {e}"
@@ -654,13 +669,15 @@ async def _process_track(self, track_id: str, track_type: int, participant):
         # Use raw forwarder for image processors - only if there are image processors
         if not hasImageProcessers:
             # No image processors, just keep the connection alive
-            self.logger.info("No image processors, video processing handled by video processors only")
+            self.logger.info(
+                "No image processors, video processing handled by video processors only"
+            )
             return
-
+
         # Initialize error tracking counters
         timeout_errors = 0
         consecutive_errors = 0
-
+
         while True:
             try:
                 # Use the raw forwarder instead of competing for track.recv()
@@ -672,7 +689,6 @@ async def _process_track(self, track_id: str, track_type: int, participant):
                     consecutive_errors = 0
 
                     if hasImageProcessers:
-
                         img = video_frame.to_image()
 
                         for processor in self.image_processors:
@@ -683,7 +699,6 @@ async def _process_track(self, track_id: str, track_type: int, participant):
                                     f"Error in image processor {type(processor).__name__}: {e}"
                                 )
 
-
                 else:
                     self.logger.warning("🎥VDP: Received empty frame")
                     consecutive_errors += 1
@@ -698,14 +713,16 @@ async def _process_track(self, track_id: str, track_type: int, participant):
                 await asyncio.sleep(backoff_delay)
 
         # Cleanup and logging
-        self.logger.info(f"🎥VDP: Video processing loop ended for track {track_id} - timeouts: {timeout_errors}, consecutive_errors: {consecutive_errors}")
+        self.logger.info(
+            f"🎥VDP: Video processing loop ended for track {track_id} - timeouts: {timeout_errors}, consecutive_errors: {consecutive_errors}"
+        )
 
     async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None:
         """Handle turn detection events."""
         # In realtime mode, the LLM handles turn detection, interruption, and responses itself
         if self.realtime_mode:
             return
-
+
         if isinstance(event, TurnStartedEvent):
             # Interrupt TTS when user starts speaking (barge-in)
             if event.speaker_id and event.speaker_id != self.agent_user.id:
@@ -730,26 +747,28 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
                 self.logger.info(
                     f"👉 Turn ended - participant {event.speaker_id} finished (duration: {event.duration}, confidence: {event.confidence})"
                 )
-
+
                 # When turn detection is enabled, trigger LLM response when user's turn ends
                 # This is the signal that the user has finished speaking and expects a response
                 if event.speaker_id and event.speaker_id != self.agent_user.id:
                     # Get the accumulated transcript for this speaker
                     transcript = self._pending_user_transcripts.get(event.speaker_id, "")
-
+
                     if transcript and transcript.strip():
-                        self.logger.info(f"🤖 Triggering LLM response after turn ended for {event.speaker_id}")
-
+                        self.logger.info(
+                            f"🤖 Triggering LLM response after turn ended for {event.speaker_id}"
+                        )
+
                         # Create participant object if we have metadata
                         participant = None
-                        if hasattr(event, 'custom') and event.custom:
+                        if hasattr(event, "custom") and event.custom:
                             # Try to extract participant info from custom metadata
-                            participant = event.custom.get('participant')
-
+                            participant = event.custom.get("participant")
+
                         # Trigger LLM response with the complete transcript
                         if self.llm:
                             await self.simple_response(transcript, participant)
-
+
                         # Clear the pending transcript for this speaker
                         self._pending_user_transcripts[event.speaker_id] = ""
 
@@ -806,12 +825,12 @@ async def _on_transcript(self, event: STTTranscriptEvent | RealtimeTranscriptEve
             )
             self.conversation.complete_message(self._user_conversation_handle)
             self._user_conversation_handle = None
-
+
         # In realtime mode, the LLM handles everything itself (STT, turn detection, responses)
         # Skip our manual LLM triggering logic
         if self.realtime_mode:
             return
-
+
         # Determine how to handle LLM triggering based on turn detection
         if self.turn_detection is not None:
             # With turn detection: accumulate transcripts and wait for TurnEndedEvent
@@ -821,7 +840,7 @@ async def _on_transcript(self, event: STTTranscriptEvent | RealtimeTranscriptEve
             else:
                 # Append to existing transcript (user might be speaking in chunks)
                 self._pending_user_transcripts[user_id] += " " + event.text
-
+
             self.logger.debug(
                 f"📝 Accumulated transcript for {user_id} (waiting for turn end): "
                 f"{self._pending_user_transcripts[user_id][:100]}..."
@@ -830,21 +849,21 @@ async def _on_transcript(self, event: STTTranscriptEvent | RealtimeTranscriptEve
             # Without turn detection: trigger LLM immediately on transcript completion
             # This is the traditional STT -> LLM flow
             if self.llm:
-                self.logger.info("🤖 Triggering LLM response immediately (no turn detection)")
-
+                self.logger.info(
+                    "🤖 Triggering LLM response immediately (no turn detection)"
+                )
+
                 # Get participant from event metadata
                 participant = None
                 if hasattr(event, "user_metadata"):
                     participant = event.user_metadata
-
+
                 await self.simple_response(event.text, participant)
 
     async def _on_stt_error(self, error):
         """Handle STT service errors."""
         self.logger.error(f"❌ STT Error: {error}")
-
-
     @property
     def realtime_mode(self) -> bool:
         """Check if the agent is in Realtime mode.
@@ -869,8 +888,7 @@ def publish_audio(self) -> bool:
 
     @property
     def publish_video(self) -> bool:
-        """Whether the agent should publish an outbound video track.
-        """
+        """Whether the agent should publish an outbound video track."""
         return len(self.video_publishers) > 0
 
     def _needs_audio_or_video_input(self) -> bool:
@@ -1000,7 +1018,9 @@ def _prepare_rtc(self):
             else:
                 framerate = 48000
                 stereo = True  # Default to stereo for WebRTC
-            self._audio_track = self.edge.create_audio_track(framerate=framerate, stereo=stereo)
+            self._audio_track = self.edge.create_audio_track(
+                framerate=framerate, stereo=stereo
+            )
             if self.tts:
                 self.tts.set_output_track(self._audio_track)
 
@@ -1012,7 +1032,6 @@ def _prepare_rtc(self):
                 self._video_track = video_publisher.publish_video_track()
                 self.logger.info("🎥 Video track initialized from video publisher")
 
-
     def _truncate_for_logging(self, obj, max_length=200):
         """Truncate object string representation for logging to prevent spam."""
         obj_str = str(obj)
diff --git a/agents-core/vision_agents/core/agents/conversation.py b/agents-core/vision_agents/core/agents/conversation.py
index 3bf3460f..7aae71c3 100644
--- a/agents-core/vision_agents/core/agents/conversation.py
+++ b/agents-core/vision_agents/core/agents/conversation.py
@@ -9,6 +9,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 @dataclass
 class Message:
     """A single utterance or assistant message within a conversation.
@@ -21,6 +22,7 @@ class Message:
         user_id: Logical user identifier associated with the message.
         id: Unique message identifier (auto-generated if not provided).
     """
+
     content: str
     original: Optional[Any] = None  # the original openai, claude or gemini message
     timestamp: Optional[datetime.datetime] = None
@@ -36,23 +38,24 @@ def __post_init__(self):
 @dataclass
 class StreamHandle:
     """Handle for managing a streaming message.
-
+
     This lightweight object is returned when starting a streaming message
     and must be passed to subsequent update operations. It encapsulates
     the message ID and user ID, preventing accidental cross-contamination
     between concurrent streams.
-
+
     Example:
         # Start a streaming message
         handle = conversation.start_streaming_message(role="assistant")
-
+
         # Update the message using the handle
         conversation.append_to_message(handle, "Hello")
         conversation.append_to_message(handle, " world!")
-
+
         # Complete the message
         conversation.complete_message(handle)
     """
+
     message_id: str
     user_id: str
 
@@ -75,50 +78,61 @@ def __init__(
     @abstractmethod
     def add_message(self, message: Message, completed: bool = True):
         """Add a message to the conversation.
-
+
         Args:
             message: The Message object to add
-            completed: If True, mark the message as completed (not generating).
+            completed: If True, mark the message as completed (not generating).
                 If False, mark as still generating. Defaults to True.
-
+
         Returns:
             The result of the add operation (implementation-specific)
         """
         ...
-
+
     @abstractmethod
-    def update_message(self, message_id: str, input_text: str, user_id: str, replace_content: bool, completed: bool):
+    def update_message(
+        self,
+        message_id: str,
+        input_text: str,
+        user_id: str,
+        replace_content: bool,
+        completed: bool,
+    ):
         """Update an existing message or create a new one if not found.
-
+
         Args:
             message_id: The ID of the message to update
             input_text: The text content to set or append
             user_id: The ID of the user who owns the message
             replace_content: If True, replace the entire message content.
                 If False, append to existing content.
            completed: If True, mark the message as completed (not generating).
                If False, mark as still generating.
-
+
         Returns:
             The result of the update operation (implementation-specific)
         """
         ...
-
+
     # Streaming message convenience methods
-    def start_streaming_message(self, role: str = "assistant", user_id: Optional[str] = None,
-                                initial_content: str = "") -> StreamHandle:
+    def start_streaming_message(
+        self,
+        role: str = "assistant",
+        user_id: Optional[str] = None,
+        initial_content: str = "",
+    ) -> StreamHandle:
         """Start a new streaming message and return a handle for subsequent operations.
-
+
         This method simplifies the management of streaming messages by returning
         a handle that encapsulates the message ID and user ID. Use the handle
         with append_to_message, replace_message, and complete_message methods.
-
+
         Args:
             role: The role of the message sender (default: "assistant")
             user_id: The ID of the user (default: same as role)
            initial_content: Initial content for the message (default: empty string)
-
+
         Returns:
             StreamHandle: A handle to use for subsequent operations on this message
-
+
         Example:
             # Simple usage
             handle = conversation.start_streaming_message()
@@ -126,15 +140,15 @@ def start_streaming_message(self, role: str = "assistant", user_id: Optional[str
             conversation.replace_message(handle, "Here's the answer: ")
             conversation.append_to_message(handle, "42")
             conversation.complete_message(handle)
-
+
             # Multiple concurrent streams
             user_handle = conversation.start_streaming_message(role="user", user_id="user123")
             assistant_handle = conversation.start_streaming_message(role="assistant")
-
+
             # Update both independently
             conversation.append_to_message(user_handle, "Hello")
             conversation.append_to_message(assistant_handle, "Hi there!")
-
+
             # Complete in any order
             conversation.complete_message(user_handle)
             conversation.complete_message(assistant_handle)
@@ -144,7 +158,7 @@ def start_streaming_message(self, role: str = "assistant", user_id: Optional[str
             content=initial_content,
             role=role,
             user_id=user_id or role,
-            id=None  # Will be assigned during add
+            id=None,  # Will be assigned during add
         )
         self.add_message(message, completed=False)
         # The message now has an ID assigned by the add_message flow
@@ -154,10 +168,10 @@ def start_streaming_message(self, role: str = "assistant", user_id: Optional[str
         assert added_message.id is not None, "Message ID should be set by add_message"
         assert added_message.user_id is not None, "User ID should be set by add_message"
         return StreamHandle(message_id=added_message.id, user_id=added_message.user_id)
-
+
     def append_to_message(self, handle: StreamHandle, text: str):
         """Append text to a streaming message identified by the handle.
-
+
         Args:
             handle: The StreamHandle returned by start_streaming_message
             text: Text to append to the message
@@ -167,12 +181,12 @@ def append_to_message(self, handle: StreamHandle, text: str):
             input_text=text,
             user_id=handle.user_id,
             replace_content=False,
-            completed=False
+            completed=False,
         )
-
+
     def replace_message(self, handle: StreamHandle, text: str):
         """Replace the content of a streaming message identified by the handle.
-
+
         Args:
             handle: The StreamHandle returned by start_streaming_message
             text: Text to replace the message content with
@@ -182,18 +196,20 @@ def replace_message(self, handle: StreamHandle, text: str):
             input_text=text,
             user_id=handle.user_id,
             replace_content=True,
-            completed=False
+            completed=False,
         )
-
+
     def complete_message(self, handle: StreamHandle):
         """Mark a streaming message as completed.
-
+
         Args:
             handle: The StreamHandle returned by start_streaming_message
         """
         # We need to find the message to get its current content
         # so we can set completed without changing the content
-        message = next((msg for msg in self.messages if msg.id == handle.message_id), None)
+        message = next(
+            (msg for msg in self.messages if msg.id == handle.message_id), None
+        )
         if message:
             # Use replace mode with the current content to avoid space issues
             self.update_message(
@@ -201,7 +217,7 @@ def complete_message(self, handle: StreamHandle):
                 input_text=message.content,
                 user_id=handle.user_id,
                 replace_content=True,
-                completed=True
+                completed=True,
             )
 
 
@@ -240,7 +256,14 @@ def add_message(self, message: Message, completed: bool = True):
         # In-memory conversation doesn't need to handle completed flag
         return None
 
-    def update_message(self, message_id: str, input_text: str, user_id: str, replace_content: bool, completed: bool):
+    def update_message(
+        self,
+        message_id: str,
+        input_text: str,
+        user_id: str,
+        replace_content: bool,
+        completed: bool,
+    ):
         """Update or create a message in-memory.
 
         If the message is not found, a new one is created with the given id.
@@ -254,10 +277,15 @@ def update_message(self, message_id: str, input_text: str, user_id: str, replace
         """
         # Find the message by id
         message = self.lookup(message_id)
-
+
         if message is None:
             logger.info(f"message {message_id} not found, create one instead")
-            return self.add_message(Message(user_id=user_id, id=message_id, content=input_text, original=None), completed=completed)
+            return self.add_message(
+                Message(
+                    user_id=user_id, id=message_id, content=input_text, original=None
+                ),
+                completed=completed,
+            )
 
         if replace_content:
             message.content = input_text
@@ -266,4 +294,3 @@ def update_message(self, message_id: str, input_text: str, user_id: str, replace
 
         # In-memory conversation just updates the message, no external API call
         return None
-
diff --git a/agents-core/vision_agents/core/agents/events.py b/agents-core/vision_agents/core/agents/events.py
index 004247f1..26608084 100644
--- a/agents-core/vision_agents/core/agents/events.py
+++ b/agents-core/vision_agents/core/agents/events.py
@@ -6,7 +6,8 @@
 @dataclass
 class AgentSayEvent(PluginBaseEvent):
     """Event emitted when the agent wants to say something."""
-    type: str = field(default='agent.say', init=False)
+
+    type: str = field(default="agent.say", init=False)
     text: str = ""
     user_id: Optional[str] = None
     metadata: Optional[Dict[str, Any]] = None
@@ -19,7 +20,8 @@ def __post_init__(self):
 @dataclass
 class AgentSayStartedEvent(PluginBaseEvent):
     """Event emitted when agent speech synthesis starts."""
-    type: str = field(default='agent.say_started', init=False)
+
+    type: str = field(default="agent.say_started", init=False)
     text: str = ""
     user_id: Optional[str] = None
     synthesis_id: Optional[str] = None
@@ -28,7 +30,8 @@ class AgentSayStartedEvent(PluginBaseEvent):
 @dataclass
 class AgentSayCompletedEvent(PluginBaseEvent):
     """Event emitted when agent speech synthesis completes."""
-    type: str = field(default='agent.say_completed', init=False)
+
+    type: str = field(default="agent.say_completed", init=False)
     text: str = ""
     user_id: Optional[str] = None
     synthesis_id: Optional[str] = None
@@ -38,7 +41,8 @@ class AgentSayCompletedEvent(PluginBaseEvent):
 @dataclass
 class AgentSayErrorEvent(PluginBaseEvent):
     """Event emitted when agent speech synthesis encounters an error."""
-    type: str = field(default='agent.say_error', init=False)
+
+    type: str = field(default="agent.say_error", init=False)
     text: str = ""
     user_id: Optional[str] = None
     error: Optional[Exception] = None
diff --git a/agents-core/vision_agents/core/cli.py b/agents-core/vision_agents/core/cli.py
index dfe2cf47..4feb7671 100644
--- a/agents-core/vision_agents/core/cli.py
+++ b/agents-core/vision_agents/core/cli.py
@@ -56,7 +56,6 @@ async def start_dispatcher(
 
         await agent_func()
 
-
     logger.info("🔚 Stream Agents dispatcher stopped")
 
 
diff --git a/agents-core/vision_agents/core/edge/edge_transport.py b/agents-core/vision_agents/core/edge/edge_transport.py
index 073728cf..c86d8c3c 100644
--- a/agents-core/vision_agents/core/edge/edge_transport.py
+++ b/agents-core/vision_agents/core/edge/edge_transport.py
@@ -1,6 +1,7 @@
 """
 Abstraction for stream vs other services here
 """
+
 import abc
 from typing import TYPE_CHECKING, Any, Optional
 
@@ -11,7 +12,6 @@
 from vision_agents.core.edge.types import User
 
 if TYPE_CHECKING:
-
     pass
 
 
@@ -55,6 +55,7 @@ async def create_conversation(self, call: Any, user: User, instructions):
         pass
 
     @abc.abstractmethod
-    def add_track_subscriber(self, track_id: str) -> Optional[aiortc.mediastreams.MediaStreamTrack]:
+    def add_track_subscriber(
+        self, track_id: str
+    ) -> Optional[aiortc.mediastreams.MediaStreamTrack]:
         pass
-
diff --git a/agents-core/vision_agents/core/edge/events.py b/agents-core/vision_agents/core/edge/events.py
index 99159b68..62d0ff69 100644
--- a/agents-core/vision_agents/core/edge/events.py
+++ b/agents-core/vision_agents/core/edge/events.py
@@ -8,7 +8,8 @@
 @dataclass
 class AudioReceivedEvent(PluginBaseEvent):
     """Event emitted when audio is received from a participant."""
-    type: str = field(default='plugin.edge.audio_received', init=False)
+
+    type: str = field(default="plugin.edge.audio_received", init=False)
     pcm_data: Optional[PcmData] = None
     participant: Optional[Any] = None
 
@@ -16,7 +17,8 @@ class AudioReceivedEvent(PluginBaseEvent):
 @dataclass
 class TrackAddedEvent(PluginBaseEvent):
     """Event emitted when a track is added to the call."""
-    type: str = field(default='plugin.edge.track_added', init=False)
+
+    type: str = field(default="plugin.edge.track_added", init=False)
     track_id: Optional[str] = None
     track_type: Optional[int] = None
     user: Optional[Any] = None
@@ -25,7 +27,8 @@ class TrackAddedEvent(PluginBaseEvent):
 @dataclass
 class TrackRemovedEvent(PluginBaseEvent):
     """Event emitted when a track is removed from the call."""
-    type: str = field(default='plugin.edge.track_removed', init=False)
+
+    type: str = field(default="plugin.edge.track_removed", init=False)
     track_id: Optional[str] = None
     track_type: Optional[int] = None
     user: Optional[Any] = None
@@ -34,6 +37,7 @@ class TrackRemovedEvent(PluginBaseEvent):
 @dataclass
 class CallEndedEvent(PluginBaseEvent):
     """Event emitted when a call ends."""
-    type: str = field(default='plugin.edge.call_ended', init=False)
+
+    type: str = field(default="plugin.edge.call_ended", init=False)
     args: Optional[tuple] = None
     kwargs: Optional[dict] = None
diff --git a/agents-core/vision_agents/core/edge/sfu_events.py b/agents-core/vision_agents/core/edge/sfu_events.py
index 9556ac76..306779f0 100644
--- a/agents-core/vision_agents/core/edge/sfu_events.py
+++ b/agents-core/vision_agents/core/edge/sfu_events.py
@@ -1,4 +1,5 @@
 """Auto-generated SFU event dataclasses. Do not edit manually."""
+
 # Generated by _generate_sfu_events.py
 from __future__ import annotations
 
@@ -21,88 +22,106 @@ def _to_dict(message) -> Dict[str, Any]:
     except Exception:
         return {}
 
+
 # ==============================================================================
 # Message Type Wrappers
 # These are wrappers for protobuf message types used in events
 # ==============================================================================
 
+
 @dataclass
 class Browser(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.Browser."""
+
     name: Optional[str] = None
     version: Optional[str] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'Browser':
+    def from_proto(cls, proto_obj) -> "Browser":
         """Create from protobuf Browser."""
         if proto_obj is None:
             return cls()
-        return cls(
-            name=proto_obj.name,
-            version=proto_obj.version
-        )
+        return cls(name=proto_obj.name, version=proto_obj.version)
+
 
 @dataclass
 class CallGrants(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.CallGrants."""
+
     can_publish_audio: Optional[bool] = None
     can_publish_video: Optional[bool] = None
     can_screenshare: Optional[bool] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'CallGrants':
+    def from_proto(cls, proto_obj) -> "CallGrants":
         """Create from protobuf CallGrants."""
         if proto_obj is None:
             return cls()
         return cls(
             can_publish_audio=proto_obj.can_publish_audio,
             can_publish_video=proto_obj.can_publish_video,
-            can_screenshare=proto_obj.can_screenshare
+            can_screenshare=proto_obj.can_screenshare,
         )
 
+
 @dataclass
 class CallState(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.CallState."""
+
     participants: Optional[List[Participant]] = None
     started_at: Optional[Any] = None
     participant_count: Optional[ParticipantCount] = None
     pins: Optional[List[Pin]] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'CallState':
+    def from_proto(cls, proto_obj) -> "CallState":
         """Create from protobuf CallState."""
         if proto_obj is None:
             return cls()
         return cls(
-            participants=[Participant.from_proto(item) for item in proto_obj.participants],
-            started_at=proto_obj.started_at if proto_obj.HasField('started_at') else None,
-            participant_count=ParticipantCount.from_proto(proto_obj.participant_count) if proto_obj.HasField('participant_count') else None,
-            pins=[Pin.from_proto(item) for item in proto_obj.pins]
+            participants=[
+                Participant.from_proto(item) for item in proto_obj.participants
+            ],
+            started_at=proto_obj.started_at
+            if proto_obj.HasField("started_at")
+            else None,
+            participant_count=ParticipantCount.from_proto(proto_obj.participant_count)
+            if proto_obj.HasField("participant_count")
+            else None,
+            pins=[Pin.from_proto(item) for item in proto_obj.pins],
         )
 
+
 @dataclass
 class ClientDetails(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.ClientDetails."""
+
     sdk: Optional[Sdk] = None
     os: Optional[OS] = None
     browser: Optional[Browser] = None
     device: Optional[Device] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'ClientDetails':
+    def from_proto(cls, proto_obj) -> "ClientDetails":
         """Create from protobuf ClientDetails."""
         if proto_obj is None:
             return cls()
         return cls(
-            sdk=Sdk.from_proto(proto_obj.sdk) if proto_obj.HasField('sdk') else None,
-            os=OS.from_proto(proto_obj.os) if proto_obj.HasField('os') else None,
-            browser=Browser.from_proto(proto_obj.browser) if proto_obj.HasField('browser') else None,
-            device=Device.from_proto(proto_obj.device) if proto_obj.HasField('device') else None
+            sdk=Sdk.from_proto(proto_obj.sdk) if proto_obj.HasField("sdk") else None,
+            os=OS.from_proto(proto_obj.os) if proto_obj.HasField("os") else None,
+            browser=Browser.from_proto(proto_obj.browser)
+            if proto_obj.HasField("browser")
+            else None,
+            device=Device.from_proto(proto_obj.device)
+            if proto_obj.HasField("device")
+            else None,
         )
 
+
 @dataclass
 class Codec(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.Codec."""
+
     payload_type: Optional[int] = None
     name: Optional[str] = None
     clock_rate: Optional[int] = None
@@ -110,7 +129,7 @@ class Codec(DataClassJsonMixin):
     fmtp: Optional[str] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'Codec':
+    def from_proto(cls, proto_obj) -> "Codec":
         """Create from protobuf Codec."""
         if proto_obj is None:
             return cls()
@@ -119,95 +138,102 @@ def from_proto(cls, proto_obj) -> 'Codec':
             name=proto_obj.name,
             clock_rate=proto_obj.clock_rate,
             encoding_parameters=proto_obj.encoding_parameters,
-            fmtp=proto_obj.fmtp
+            fmtp=proto_obj.fmtp,
         )
 
+
 @dataclass
 class Device(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.Device."""
+
     name: Optional[str] = None
     version: Optional[str] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'Device':
+    def from_proto(cls, proto_obj) -> "Device":
         """Create from protobuf Device."""
         if proto_obj is None:
             return cls()
-        return cls(
-            name=proto_obj.name,
-            version=proto_obj.version
-        )
+        return cls(name=proto_obj.name, version=proto_obj.version)
+
 
 @dataclass
 class Error(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.Error.
-
+
     Enum fields (use values from models_pb2):
       - code: ErrorCode
     """
+
     code: Optional[int] = None
     message: Optional[str] = None
     should_retry: Optional[bool] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'Error':
+    def from_proto(cls, proto_obj) -> "Error":
         """Create from protobuf Error."""
         if proto_obj is None:
             return cls()
         return cls(
             code=proto_obj.code,
             message=proto_obj.message,
-            should_retry=proto_obj.should_retry
+            should_retry=proto_obj.should_retry,
         )
 
+
 @dataclass
 class ICETrickle(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.ICETrickle.
-
+
     Enum fields (use values from models_pb2):
       - peer_type: PeerType
     """
+
     peer_type: Optional[int] = None
     ice_candidate: Optional[str] = None
     session_id: Optional[str] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'ICETrickle':
+    def from_proto(cls, proto_obj) -> "ICETrickle":
         """Create from protobuf ICETrickle."""
         if proto_obj is None:
             return cls()
         return cls(
             peer_type=proto_obj.peer_type,
             ice_candidate=proto_obj.ice_candidate,
-            session_id=proto_obj.session_id
+            session_id=proto_obj.session_id,
         )
 
+
 @dataclass
 class OS(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.OS."""
+
     name: Optional[str] = None
     version: Optional[str] = None
     architecture: Optional[str] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'OS':
+    def from_proto(cls, proto_obj) -> "OS":
         """Create from protobuf OS."""
         if proto_obj is None:
             return cls()
         return cls(
             name=proto_obj.name,
             version=proto_obj.version,
-            architecture=proto_obj.architecture
+            architecture=proto_obj.architecture,
         )
 
+
 @dataclass
 class Participant(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.Participant.
-
+
     Enum fields (use values from models_pb2):
       - published_tracks: TrackType
       - connection_quality: ConnectionQuality
     """
+
     user_id: Optional[str] = None
     session_id: Optional[str] = None
     published_tracks: Optional[List[int]] = None
@@ -223,7 +249,7 @@ class Participant(DataClassJsonMixin):
     roles: Optional[List[str]] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'Participant':
+    def from_proto(cls, proto_obj) -> "Participant":
         """Create from protobuf Participant."""
         if proto_obj is None:
             return cls()
@@ -231,7 +257,7 @@ def from_proto(cls, proto_obj) -> 'Participant':
             user_id=proto_obj.user_id,
             session_id=proto_obj.session_id,
             published_tracks=list(proto_obj.published_tracks),
-            joined_at=proto_obj.joined_at if proto_obj.HasField('joined_at') else None,
+            joined_at=proto_obj.joined_at if proto_obj.HasField("joined_at") else None,
             track_lookup_prefix=proto_obj.track_lookup_prefix,
             connection_quality=proto_obj.connection_quality,
             is_speaking=proto_obj.is_speaking,
@@ -239,49 +265,49 @@ def from_proto(cls, proto_obj) -> 'Participant':
             audio_level=proto_obj.audio_level,
             name=proto_obj.name,
             image=proto_obj.image,
-            custom=proto_obj.custom if proto_obj.HasField('custom') else None,
-            roles=list(proto_obj.roles)
+            custom=proto_obj.custom if proto_obj.HasField("custom") else None,
+            roles=list(proto_obj.roles),
         )
 
+
 @dataclass
 class ParticipantCount(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.ParticipantCount."""
+
     total: Optional[int] = None
     anonymous: Optional[int] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'ParticipantCount':
+    def from_proto(cls, proto_obj) -> "ParticipantCount":
         """Create from protobuf ParticipantCount."""
         if proto_obj is None:
             return cls()
-        return cls(
-            total=proto_obj.total,
-            anonymous=proto_obj.anonymous
-        )
+        return cls(total=proto_obj.total, anonymous=proto_obj.anonymous)
+
 
 @dataclass
 class Pin(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.Pin."""
+
     user_id: Optional[str] = None
     session_id: Optional[str] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'Pin':
+    def from_proto(cls, proto_obj) -> "Pin":
         """Create from protobuf Pin."""
         if proto_obj is None:
             return cls()
-        return cls(
-            user_id=proto_obj.user_id,
-            session_id=proto_obj.session_id
-        )
+        return cls(user_id=proto_obj.user_id, session_id=proto_obj.session_id)
+
 
 @dataclass
 class PublishOption(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.PublishOption.
-
+
     Enum fields (use values from models_pb2):
       - track_type: TrackType
     """
+
     track_type: Optional[int] = None
     codec: Optional[Codec] = None
     bitrate: Optional[int] = None
@@ -293,36 +319,42 @@ class PublishOption(DataClassJsonMixin):
     use_single_layer: Optional[bool] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'PublishOption':
+    def from_proto(cls, proto_obj) -> "PublishOption":
         """Create from protobuf PublishOption."""
         if proto_obj is None:
             return cls()
         return cls(
             track_type=proto_obj.track_type,
-            codec=Codec.from_proto(proto_obj.codec) if proto_obj.HasField('codec') else None,
+            codec=Codec.from_proto(proto_obj.codec)
+            if proto_obj.HasField("codec")
+            else None,
             bitrate=proto_obj.bitrate,
             fps=proto_obj.fps,
             max_spatial_layers=proto_obj.max_spatial_layers,
             max_temporal_layers=proto_obj.max_temporal_layers,
-            video_dimension=VideoDimension.from_proto(proto_obj.video_dimension) if proto_obj.HasField('video_dimension') else None,
+            video_dimension=VideoDimension.from_proto(proto_obj.video_dimension)
+            if proto_obj.HasField("video_dimension")
+            else None,
             id=proto_obj.id,
-            use_single_layer=proto_obj.use_single_layer
+            use_single_layer=proto_obj.use_single_layer,
         )
 
+
 @dataclass
 class Sdk(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.Sdk.
-
+
     Enum fields (use values from models_pb2):
       - type: SdkType
     """
+
     type: Optional[int] = None
     major: Optional[str] = None
     minor: Optional[str] = None
     patch: Optional[str] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'Sdk':
+    def from_proto(cls, proto_obj) -> "Sdk":
         """Create from protobuf Sdk."""
         if proto_obj is None:
             return cls()
@@ -330,36 +362,40 @@ def from_proto(cls, proto_obj) -> 'Sdk':
             type=proto_obj.type,
             major=proto_obj.major,
             minor=proto_obj.minor,
-            patch=proto_obj.patch
+            patch=proto_obj.patch,
         )
 
+
 @dataclass
 class SubscribeOption(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.SubscribeOption.
-
+
     Enum fields (use values from models_pb2):
       - track_type: TrackType
     """
+
     track_type: Optional[int] = None
     codecs: Optional[List[Codec]] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'SubscribeOption':
+    def from_proto(cls, proto_obj) -> "SubscribeOption":
         """Create from protobuf SubscribeOption."""
         if proto_obj is None:
             return cls()
         return cls(
             track_type=proto_obj.track_type,
-            codecs=[Codec.from_proto(item) for item in proto_obj.codecs]
+            codecs=[Codec.from_proto(item) for item in proto_obj.codecs],
         )
 
+
 @dataclass
 class TrackInfo(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.TrackInfo.
-
+
     Enum fields (use values from models_pb2):
       - track_type: TrackType
     """
+
     track_id: Optional[str] = None
     track_type: Optional[int] = None
     layers: Optional[List[VideoLayer]] = None
@@ -372,7 +408,7 @@ class TrackInfo(DataClassJsonMixin):
     publish_option_id: Optional[int] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'TrackInfo':
+    def from_proto(cls, proto_obj) -> "TrackInfo":
         """Create from protobuf TrackInfo."""
         if proto_obj is None:
             return cls()
@@ -385,33 +421,36 @@ def from_proto(cls, proto_obj) -> 'TrackInfo':
             stereo=proto_obj.stereo,
             red=proto_obj.red,
             muted=proto_obj.muted,
-            codec=Codec.from_proto(proto_obj.codec) if proto_obj.HasField('codec') else None,
-            publish_option_id=proto_obj.publish_option_id
+            codec=Codec.from_proto(proto_obj.codec)
+            if proto_obj.HasField("codec")
+            else None,
+            publish_option_id=proto_obj.publish_option_id,
         )
 
+
 @dataclass
 class VideoDimension(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.VideoDimension."""
+
     width: Optional[int] = None
     height: Optional[int] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'VideoDimension':
+    def from_proto(cls, proto_obj) -> "VideoDimension":
         """Create from protobuf VideoDimension."""
         if proto_obj is None:
             return cls()
-        return cls(
-            width=proto_obj.width,
-            height=proto_obj.height
-        )
+        return cls(width=proto_obj.width, height=proto_obj.height)
+
 
 @dataclass
 class VideoLayer(DataClassJsonMixin):
     """Wrapper for stream.video.sfu.models.VideoLayer.
-
+
     Enum fields (use values from models_pb2):
       - quality: VideoQuality
     """
+
     rid: Optional[str] = None
     video_dimension: Optional[VideoDimension] = None
     bitrate: Optional[int] = None
@@ -419,16 +458,18 @@ class VideoLayer(DataClassJsonMixin):
     quality: Optional[int] = None
 
     @classmethod
-    def from_proto(cls, proto_obj) -> 'VideoLayer':
+    def from_proto(cls, proto_obj) -> "VideoLayer":
         """Create from protobuf VideoLayer."""
         if proto_obj is None:
             return cls()
         return cls(
             rid=proto_obj.rid,
-            video_dimension=VideoDimension.from_proto(proto_obj.video_dimension) if proto_obj.HasField('video_dimension') else None,
+            video_dimension=VideoDimension.from_proto(proto_obj.video_dimension)
+            if proto_obj.HasField("video_dimension")
+            else None,
             bitrate=proto_obj.bitrate,
             fps=proto_obj.fps,
-            quality=proto_obj.quality
+            quality=proto_obj.quality,
         )
 
 
@@ -436,9 +477,11 @@ def from_proto(cls, proto_obj) -> 'VideoLayer':
 # Event Classes
 # ==============================================================================
 
+
 @dataclass
 class AudioLevelEvent(BaseEvent):
     """Dataclass event for video.sfu.event.events_pb2.AudioLevel."""
+
     type: str = field(default="stream.video.sfu.event.AudioLevel", init=False)
     payload: Optional[events_pb2.AudioLevel] = field(default=None, repr=False)
 
@@ -447,21 +490,21 @@ def user_id(self) -> Optional[str]:
         """Access user_id field from the protobuf payload."""
         if self.payload is None:
             return None
-        return getattr(self.payload, 'user_id', None)
+        return getattr(self.payload, "user_id", None)
 
     @property
     def level(self) -> Optional[float]:
         """Access level field from the protobuf payload."""
         if self.payload is None:
             return None
-        return getattr(self.payload, 'level', None)
+        return getattr(self.payload, "level", None)
 
     @property
     def is_speaking(self) -> Optional[bool]:
         """Access is_speaking field from the protobuf payload."""
         if self.payload is None:
             return None
-        return getattr(self.payload, 'is_speaking', None)
+        return getattr(self.payload, "is_speaking", None)
 
     @classmethod
     def from_proto(cls, proto_obj: events_pb2.AudioLevel, **extra):
@@ -478,11 +521,15 @@ def __getattr__(self, item: str):
         """Delegate attribute access to protobuf payload."""
         if self.payload is not None:
             return getattr(self.payload, item)
-        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
+        raise AttributeError(
+            f"'{self.__class__.__name__}' object has no attribute '{item}'"
+        )
+
 
 @dataclass
 class AudioLevelChangedEvent(BaseEvent):
     """Dataclass event for video.sfu.event.events_pb2.AudioLevelChanged."""
+
     type: str = field(default="stream.video.sfu.event.AudioLevelChanged", init=False)
     payload: Optional[events_pb2.AudioLevelChanged] = field(default=None, repr=False)
 
@@ -491,7 +538,7 @@ def audio_levels(self) -> Optional[List[Any]]:
         """Access audio_levels field from the protobuf payload."""
         if self.payload is None:
             return None
-        return getattr(self.payload, 'audio_levels', None)
+        return getattr(self.payload, "audio_levels", None)
 
     @classmethod
     def from_proto(cls, proto_obj: events_pb2.AudioLevelChanged, **extra):
@@ -508,11 +555,15 @@ def __getattr__(self, item: str):
         """Delegate attribute access to protobuf payload."""
         if self.payload is not None:
             return getattr(self.payload, item)
-        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
+        raise AttributeError(
+            f"'{self.__class__.__name__}' object has no attribute '{item}'"
+        )
+
 
 @dataclass
 class AudioSenderEvent(BaseEvent):
     """Dataclass event for video.sfu.event.events_pb2.AudioSender."""
+
     type: str = field(default="stream.video.sfu.event.AudioSender", init=False)
     payload: Optional[events_pb2.AudioSender] = field(default=None, repr=False)
 
@@ -521,7 +572,7 @@ def codec(self) -> Optional[Codec]:
         """Access codec field from the protobuf payload."""
         if self.payload is None:
             return None
-        proto_val = getattr(self.payload, 'codec', None)
+        proto_val = getattr(self.payload, "codec", None)
         return Codec.from_proto(proto_val) if proto_val is not None else None
 
     @property
@@ -529,14 +580,14 @@ def track_type(self) -> Optional[int]:
         """Access track_type field from the protobuf payload. Use models_pb2.TrackType enum."""
         if self.payload is None:
             return None
-        return getattr(self.payload, 'track_type', None)
+        return getattr(self.payload, "track_type", None)
 
     @property
     def publish_option_id(self) -> Optional[int]:
         """Access publish_option_id field from the protobuf payload."""
         if self.payload is None:
             return None
-        return getattr(self.payload, 'publish_option_id', None)
+        return getattr(self.payload, "publish_option_id", None)
 
     @classmethod
     def from_proto(cls, proto_obj: events_pb2.AudioSender, **extra):
@@ -553,11 +604,15 @@ def __getattr__(self, item: str):
         """Delegate attribute access to protobuf payload."""
         if self.payload is not None:
             return getattr(self.payload, item)
-        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
+        raise AttributeError(
+            f"'{self.__class__.__name__}' object has no attribute '{item}'"
+        )
+
 
 @dataclass
 class CallEndedEvent(BaseEvent):
     """Dataclass event for video.sfu.event.events_pb2.CallEnded."""
+
     type: str = field(default="stream.video.sfu.event.CallEnded", init=False)
     payload: Optional[events_pb2.CallEnded] = field(default=None, repr=False)
 
@@ -566,7 +621,7 @@ def reason(self) -> Optional[int]:
         """Access reason field from the protobuf payload.
Use models_pb2.CallEndedReason enum.""" if self.payload is None: return None - return getattr(self.payload, 'reason', None) + return getattr(self.payload, "reason", None) @classmethod def from_proto(cls, proto_obj: events_pb2.CallEnded, **extra): @@ -583,11 +638,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class CallGrantsUpdatedEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.CallGrantsUpdated.""" + type: str = field(default="stream.video.sfu.event.CallGrantsUpdated", init=False) payload: Optional[events_pb2.CallGrantsUpdated] = field(default=None, repr=False) @@ -596,7 +655,7 @@ def current_grants(self) -> Optional[CallGrants]: """Access current_grants field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'current_grants', None) + proto_val = getattr(self.payload, "current_grants", None) return CallGrants.from_proto(proto_val) if proto_val is not None else None @property @@ -604,7 +663,7 @@ def message(self) -> Optional[str]: """Access message field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'message', None) + return getattr(self.payload, "message", None) @classmethod def from_proto(cls, proto_obj: events_pb2.CallGrantsUpdated, **extra): @@ -621,11 +680,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ChangePublishOptionsEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.ChangePublishOptions.""" + type: str = field(default="stream.video.sfu.event.ChangePublishOptions", init=False) payload: Optional[events_pb2.ChangePublishOptions] = field(default=None, repr=False) @@ -634,15 +697,19 @@ def publish_options(self) -> Optional[List[PublishOption]]: """Access publish_options field from the protobuf payload.""" if self.payload is None: return None - proto_list = getattr(self.payload, 'publish_options', []) - return [PublishOption.from_proto(item) for item in proto_list] if proto_list else None + proto_list = getattr(self.payload, "publish_options", []) + return ( + [PublishOption.from_proto(item) for item in proto_list] + if proto_list + else None + ) @property def reason(self) -> Optional[str]: """Access reason field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'reason', None) + return getattr(self.payload, "reason", None) @classmethod def from_proto(cls, proto_obj: events_pb2.ChangePublishOptions, **extra): @@ -659,13 +726,21 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ChangePublishOptionsCompleteEvent(BaseEvent): """Dataclass event for 
video.sfu.event.events_pb2.ChangePublishOptionsComplete.""" - type: str = field(default="stream.video.sfu.event.ChangePublishOptionsComplete", init=False) - payload: Optional[events_pb2.ChangePublishOptionsComplete] = field(default=None, repr=False) + + type: str = field( + default="stream.video.sfu.event.ChangePublishOptionsComplete", init=False + ) + payload: Optional[events_pb2.ChangePublishOptionsComplete] = field( + default=None, repr=False + ) @classmethod def from_proto(cls, proto_obj: events_pb2.ChangePublishOptionsComplete, **extra): @@ -682,11 +757,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ChangePublishQualityEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.ChangePublishQuality.""" + type: str = field(default="stream.video.sfu.event.ChangePublishQuality", init=False) payload: Optional[events_pb2.ChangePublishQuality] = field(default=None, repr=False) @@ -695,14 +774,14 @@ def audio_senders(self) -> Optional[List[Any]]: """Access audio_senders field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'audio_senders', None) + return getattr(self.payload, "audio_senders", None) @property def video_senders(self) -> Optional[List[Any]]: """Access video_senders field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'video_senders', None) + return getattr(self.payload, "video_senders", None) @classmethod def from_proto(cls, proto_obj: events_pb2.ChangePublishQuality, **extra): @@ -719,20 +798,28 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ConnectionQualityChangedEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.ConnectionQualityChanged.""" - type: str = field(default="stream.video.sfu.event.ConnectionQualityChanged", init=False) - payload: Optional[events_pb2.ConnectionQualityChanged] = field(default=None, repr=False) + + type: str = field( + default="stream.video.sfu.event.ConnectionQualityChanged", init=False + ) + payload: Optional[events_pb2.ConnectionQualityChanged] = field( + default=None, repr=False + ) @property def connection_quality_updates(self) -> Optional[List[Any]]: """Access connection_quality_updates field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'connection_quality_updates', None) + return getattr(self.payload, "connection_quality_updates", None) @classmethod def from_proto(cls, proto_obj: events_pb2.ConnectionQualityChanged, **extra): @@ -749,27 +836,35 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ConnectionQualityInfoEvent(BaseEvent): """Dataclass event for 
video.sfu.event.events_pb2.ConnectionQualityInfo.""" - type: str = field(default="stream.video.sfu.event.ConnectionQualityInfo", init=False) - payload: Optional[events_pb2.ConnectionQualityInfo] = field(default=None, repr=False) + + type: str = field( + default="stream.video.sfu.event.ConnectionQualityInfo", init=False + ) + payload: Optional[events_pb2.ConnectionQualityInfo] = field( + default=None, repr=False + ) @property def user_id(self) -> Optional[str]: """Access user_id field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'user_id', None) + return getattr(self.payload, "user_id", None) @property def connection_quality(self) -> Optional[int]: """Access connection_quality field from the protobuf payload. Use models_pb2.ConnectionQuality enum.""" if self.payload is None: return None - return getattr(self.payload, 'connection_quality', None) + return getattr(self.payload, "connection_quality", None) @classmethod def from_proto(cls, proto_obj: events_pb2.ConnectionQualityInfo, **extra): @@ -786,20 +881,28 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class DominantSpeakerChangedEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.DominantSpeakerChanged.""" - type: str = field(default="stream.video.sfu.event.DominantSpeakerChanged", init=False) - payload: Optional[events_pb2.DominantSpeakerChanged] = field(default=None, repr=False) + + type: str = field( + default="stream.video.sfu.event.DominantSpeakerChanged", init=False + ) + payload: Optional[events_pb2.DominantSpeakerChanged] = field( + default=None, repr=False + ) @property def user_id(self) -> Optional[str]: """Access user_id field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'user_id', None) + return getattr(self.payload, "user_id", None) @classmethod def from_proto(cls, proto_obj: events_pb2.DominantSpeakerChanged, **extra): @@ -816,11 +919,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ErrorEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.Error.""" + type: str = field(default="stream.video.sfu.event.Error", init=False) payload: Optional[events_pb2.Error] = field(default=None, repr=False) @@ -829,7 +936,7 @@ def error(self) -> Optional[Error]: """Access error field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'error', None) + proto_val = getattr(self.payload, "error", None) return Error.from_proto(proto_val) if proto_val is not None else None @property @@ -837,7 +944,7 @@ def reconnect_strategy(self) -> Optional[int]: """Access reconnect_strategy field from the protobuf payload. 
Use models_pb2.WebsocketReconnectStrategy enum.""" if self.payload is None: return None - return getattr(self.payload, 'reconnect_strategy', None) + return getattr(self.payload, "reconnect_strategy", None) @classmethod def from_proto(cls, proto_obj: events_pb2.Error, **extra): @@ -854,11 +961,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class GoAwayEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.GoAway.""" + type: str = field(default="stream.video.sfu.event.GoAway", init=False) payload: Optional[events_pb2.GoAway] = field(default=None, repr=False) @@ -867,7 +978,7 @@ def reason(self) -> Optional[int]: """Access reason field from the protobuf payload. Use models_pb2.GoAwayReason enum.""" if self.payload is None: return None - return getattr(self.payload, 'reason', None) + return getattr(self.payload, "reason", None) @classmethod def from_proto(cls, proto_obj: events_pb2.GoAway, **extra): @@ -884,11 +995,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class HealthCheckRequestEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.HealthCheckRequest.""" + type: str = field(default="stream.video.sfu.event.HealthCheckRequest", init=False) payload: Optional[events_pb2.HealthCheckRequest] = field(default=None, repr=False) @@ -907,11 +1022,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class HealthCheckResponseEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.HealthCheckResponse.""" + type: str = field(default="stream.video.sfu.event.HealthCheckResponse", init=False) payload: Optional[events_pb2.HealthCheckResponse] = field(default=None, repr=False) @@ -920,7 +1039,7 @@ def participant_count(self) -> Optional[ParticipantCount]: """Access participant_count field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'participant_count', None) + proto_val = getattr(self.payload, "participant_count", None) return ParticipantCount.from_proto(proto_val) if proto_val is not None else None @classmethod @@ -938,11 +1057,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ICERestartEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.ICERestart.""" + type: str = field(default="stream.video.sfu.event.ICERestart", init=False) payload: Optional[events_pb2.ICERestart] = field(default=None, repr=False) 
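
The hunks in this generated file are purely mechanical ruff-format changes (quote style and line wrapping), so the easiest way to review them is against a concrete use of the API they touch. The sketch below is illustrative only: the import path for the generated module is a guess (the diff does not name it), and events_pb2 is assumed to live beside the models_pb2 package imported elsewhere in this diff.

# Hedged usage sketch for the generated event wrappers; both import paths are assumptions.
from getstream.video.rtc.pb.stream.video.sfu.event import events_pb2  # assumed location
from vision_agents.core.edge.sfu_events import AudioLevelEvent  # hypothetical module name

# Wrap a populated protobuf payload.
proto = events_pb2.AudioLevel(user_id="alice", level=0.7, is_speaking=True)
event = AudioLevelEvent.from_proto(proto)

# Generated properties null-check the payload: a payload-less event yields None,
# while a populated one reads fields straight off the protobuf.
assert AudioLevelEvent().user_id is None
assert event.user_id == "alice" and event.is_speaking

# Fields without a dedicated property remain reachable: __getattr__ delegates to
# the payload and raises AttributeError only when the payload is missing too.
print(event.level)
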
@@ -951,7 +1074,7 @@ def peer_type(self) -> Optional[int]: """Access peer_type field from the protobuf payload. Use models_pb2.PeerType enum.""" if self.payload is None: return None - return getattr(self.payload, 'peer_type', None) + return getattr(self.payload, "peer_type", None) @classmethod def from_proto(cls, proto_obj: events_pb2.ICERestart, **extra): @@ -968,11 +1091,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ICETrickleEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.ICETrickle.""" + type: str = field(default="stream.video.sfu.event.ICETrickle", init=False) payload: Optional[events_pb2.ICETrickle] = field(default=None, repr=False) @@ -981,14 +1108,14 @@ def peer_type(self) -> Optional[int]: """Access peer_type field from the protobuf payload. Use models_pb2.PeerType enum.""" if self.payload is None: return None - return getattr(self.payload, 'peer_type', None) + return getattr(self.payload, "peer_type", None) @property def ice_candidate(self) -> Optional[str]: """Access ice_candidate field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'ice_candidate', None) + return getattr(self.payload, "ice_candidate", None) @classmethod def from_proto(cls, proto_obj: events_pb2.ICETrickle, **extra): @@ -1005,20 +1132,28 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class InboundStateNotificationEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.InboundStateNotification.""" - type: str = field(default="stream.video.sfu.event.InboundStateNotification", init=False) - payload: Optional[events_pb2.InboundStateNotification] = field(default=None, repr=False) + + type: str = field( + default="stream.video.sfu.event.InboundStateNotification", init=False + ) + payload: Optional[events_pb2.InboundStateNotification] = field( + default=None, repr=False + ) @property def inbound_video_states(self) -> Optional[List[Any]]: """Access inbound_video_states field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'inbound_video_states', None) + return getattr(self.payload, "inbound_video_states", None) @classmethod def from_proto(cls, proto_obj: events_pb2.InboundStateNotification, **extra): @@ -1035,11 +1170,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class InboundVideoStateEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.InboundVideoState.""" + type: str = field(default="stream.video.sfu.event.InboundVideoState", init=False) payload: Optional[events_pb2.InboundVideoState] = field(default=None, repr=False) @@ -1048,21 +1187,21 @@ def user_id(self) -> 
Optional[str]: """Access user_id field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'user_id', None) + return getattr(self.payload, "user_id", None) @property def track_type(self) -> Optional[int]: """Access track_type field from the protobuf payload. Use models_pb2.TrackType enum.""" if self.payload is None: return None - return getattr(self.payload, 'track_type', None) + return getattr(self.payload, "track_type", None) @property def paused(self) -> Optional[bool]: """Access paused field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'paused', None) + return getattr(self.payload, "paused", None) @classmethod def from_proto(cls, proto_obj: events_pb2.InboundVideoState, **extra): @@ -1079,11 +1218,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class JoinRequestEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.JoinRequest.""" + type: str = field(default="stream.video.sfu.event.JoinRequest", init=False) payload: Optional[events_pb2.JoinRequest] = field(default=None, repr=False) @@ -1092,28 +1235,28 @@ def token(self) -> Optional[str]: """Access token field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'token', None) + return getattr(self.payload, "token", None) @property def subscriber_sdp(self) -> Optional[str]: """Access subscriber_sdp field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'subscriber_sdp', None) + return getattr(self.payload, "subscriber_sdp", None) @property def publisher_sdp(self) -> Optional[str]: """Access publisher_sdp field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'publisher_sdp', None) + return getattr(self.payload, "publisher_sdp", None) @property def client_details(self) -> Optional[ClientDetails]: """Access client_details field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'client_details', None) + proto_val = getattr(self.payload, "client_details", None) return ClientDetails.from_proto(proto_val) if proto_val is not None else None @property @@ -1121,44 +1264,52 @@ def migration(self) -> Optional[Any]: """Access migration field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'migration', None) + return getattr(self.payload, "migration", None) @property def fast_reconnect(self) -> Optional[bool]: """Access fast_reconnect field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'fast_reconnect', None) + return getattr(self.payload, "fast_reconnect", None) @property def reconnect_details(self) -> Optional[Any]: """Access reconnect_details field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'reconnect_details', None) + return getattr(self.payload, "reconnect_details", None) @property def preferred_publish_options(self) -> Optional[List[PublishOption]]: """Access preferred_publish_options field from the protobuf payload.""" if self.payload is None: return None - proto_list = 
getattr(self.payload, 'preferred_publish_options', []) - return [PublishOption.from_proto(item) for item in proto_list] if proto_list else None + proto_list = getattr(self.payload, "preferred_publish_options", []) + return ( + [PublishOption.from_proto(item) for item in proto_list] + if proto_list + else None + ) @property def preferred_subscribe_options(self) -> Optional[List[SubscribeOption]]: """Access preferred_subscribe_options field from the protobuf payload.""" if self.payload is None: return None - proto_list = getattr(self.payload, 'preferred_subscribe_options', []) - return [SubscribeOption.from_proto(item) for item in proto_list] if proto_list else None + proto_list = getattr(self.payload, "preferred_subscribe_options", []) + return ( + [SubscribeOption.from_proto(item) for item in proto_list] + if proto_list + else None + ) @property def capabilities(self) -> Optional[List[int]]: """Access capabilities field from the protobuf payload. Use models_pb2.ClientCapability enum.""" if self.payload is None: return None - return getattr(self.payload, 'capabilities', None) + return getattr(self.payload, "capabilities", None) @classmethod def from_proto(cls, proto_obj: events_pb2.JoinRequest, **extra): @@ -1175,11 +1326,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class JoinResponseEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.JoinResponse.""" + type: str = field(default="stream.video.sfu.event.JoinResponse", init=False) payload: Optional[events_pb2.JoinResponse] = field(default=None, repr=False) @@ -1188,7 +1343,7 @@ def call_state(self) -> Optional[CallState]: """Access call_state field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'call_state', None) + proto_val = getattr(self.payload, "call_state", None) return CallState.from_proto(proto_val) if proto_val is not None else None @property @@ -1196,22 +1351,26 @@ def reconnected(self) -> Optional[bool]: """Access reconnected field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'reconnected', None) + return getattr(self.payload, "reconnected", None) @property def fast_reconnect_deadline_seconds(self) -> Optional[int]: """Access fast_reconnect_deadline_seconds field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'fast_reconnect_deadline_seconds', None) + return getattr(self.payload, "fast_reconnect_deadline_seconds", None) @property def publish_options(self) -> Optional[List[PublishOption]]: """Access publish_options field from the protobuf payload.""" if self.payload is None: return None - proto_list = getattr(self.payload, 'publish_options', []) - return [PublishOption.from_proto(item) for item in proto_list] if proto_list else None + proto_list = getattr(self.payload, "publish_options", []) + return ( + [PublishOption.from_proto(item) for item in proto_list] + if proto_list + else None + ) @classmethod def from_proto(cls, proto_obj: events_pb2.JoinResponse, **extra): @@ -1228,11 +1387,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise 
AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class LeaveCallRequestEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.LeaveCallRequest.""" + type: str = field(default="stream.video.sfu.event.LeaveCallRequest", init=False) payload: Optional[events_pb2.LeaveCallRequest] = field(default=None, repr=False) @@ -1241,7 +1404,7 @@ def reason(self) -> Optional[str]: """Access reason field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'reason', None) + return getattr(self.payload, "reason", None) @classmethod def from_proto(cls, proto_obj: events_pb2.LeaveCallRequest, **extra): @@ -1258,11 +1421,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class MigrationEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.Migration.""" + type: str = field(default="stream.video.sfu.event.Migration", init=False) payload: Optional[events_pb2.Migration] = field(default=None, repr=False) @@ -1271,22 +1438,24 @@ def from_sfu_id(self) -> Optional[str]: """Access from_sfu_id field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'from_sfu_id', None) + return getattr(self.payload, "from_sfu_id", None) @property def announced_tracks(self) -> Optional[List[TrackInfo]]: """Access announced_tracks field from the protobuf payload.""" if self.payload is None: return None - proto_list = getattr(self.payload, 'announced_tracks', []) - return [TrackInfo.from_proto(item) for item in proto_list] if proto_list else None + proto_list = getattr(self.payload, "announced_tracks", []) + return ( + [TrackInfo.from_proto(item) for item in proto_list] if proto_list else None + ) @property def subscriptions(self) -> Optional[List[Any]]: """Access subscriptions field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'subscriptions', None) + return getattr(self.payload, "subscriptions", None) @classmethod def from_proto(cls, proto_obj: events_pb2.Migration, **extra): @@ -1303,11 +1472,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ParticipantJoinedEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.ParticipantJoined.""" + type: str = field(default="stream.video.sfu.event.ParticipantJoined", init=False) payload: Optional[events_pb2.ParticipantJoined] = field(default=None, repr=False) @@ -1316,14 +1489,14 @@ def call_cid(self) -> Optional[str]: """Access call_cid field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'call_cid', None) + return getattr(self.payload, "call_cid", None) @property def participant(self) -> Optional[Participant]: """Access participant field from the protobuf payload.""" if self.payload is None: return None - proto_val = 
getattr(self.payload, 'participant', None) + proto_val = getattr(self.payload, "participant", None) return Participant.from_proto(proto_val) if proto_val is not None else None @classmethod @@ -1341,11 +1514,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ParticipantLeftEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.ParticipantLeft.""" + type: str = field(default="stream.video.sfu.event.ParticipantLeft", init=False) payload: Optional[events_pb2.ParticipantLeft] = field(default=None, repr=False) @@ -1354,14 +1531,14 @@ def call_cid(self) -> Optional[str]: """Access call_cid field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'call_cid', None) + return getattr(self.payload, "call_cid", None) @property def participant(self) -> Optional[Participant]: """Access participant field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'participant', None) + proto_val = getattr(self.payload, "participant", None) return Participant.from_proto(proto_val) if proto_val is not None else None @classmethod @@ -1379,13 +1556,21 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ParticipantMigrationCompleteEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.ParticipantMigrationComplete.""" - type: str = field(default="stream.video.sfu.event.ParticipantMigrationComplete", init=False) - payload: Optional[events_pb2.ParticipantMigrationComplete] = field(default=None, repr=False) + + type: str = field( + default="stream.video.sfu.event.ParticipantMigrationComplete", init=False + ) + payload: Optional[events_pb2.ParticipantMigrationComplete] = field( + default=None, repr=False + ) @classmethod def from_proto(cls, proto_obj: events_pb2.ParticipantMigrationComplete, **extra): @@ -1402,11 +1587,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ParticipantUpdatedEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.ParticipantUpdated.""" + type: str = field(default="stream.video.sfu.event.ParticipantUpdated", init=False) payload: Optional[events_pb2.ParticipantUpdated] = field(default=None, repr=False) @@ -1415,14 +1604,14 @@ def call_cid(self) -> Optional[str]: """Access call_cid field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'call_cid', None) + return getattr(self.payload, "call_cid", None) @property def participant(self) -> Optional[Participant]: """Access participant field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'participant', None) + 
proto_val = getattr(self.payload, "participant", None) return Participant.from_proto(proto_val) if proto_val is not None else None @classmethod @@ -1440,11 +1629,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class PinsChangedEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.PinsChanged.""" + type: str = field(default="stream.video.sfu.event.PinsChanged", init=False) payload: Optional[events_pb2.PinsChanged] = field(default=None, repr=False) @@ -1453,7 +1646,7 @@ def pins(self) -> Optional[List[Pin]]: """Access pins field from the protobuf payload.""" if self.payload is None: return None - proto_list = getattr(self.payload, 'pins', []) + proto_list = getattr(self.payload, "pins", []) return [Pin.from_proto(item) for item in proto_list] if proto_list else None @classmethod @@ -1471,11 +1664,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class PublisherAnswerEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.PublisherAnswer.""" + type: str = field(default="stream.video.sfu.event.PublisherAnswer", init=False) payload: Optional[events_pb2.PublisherAnswer] = field(default=None, repr=False) @@ -1484,7 +1681,7 @@ def sdp(self) -> Optional[str]: """Access sdp field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'sdp', None) + return getattr(self.payload, "sdp", None) @classmethod def from_proto(cls, proto_obj: events_pb2.PublisherAnswer, **extra): @@ -1501,11 +1698,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class ReconnectDetailsEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.ReconnectDetails.""" + type: str = field(default="stream.video.sfu.event.ReconnectDetails", init=False) payload: Optional[events_pb2.ReconnectDetails] = field(default=None, repr=False) @@ -1514,50 +1715,52 @@ def strategy(self) -> Optional[int]: """Access strategy field from the protobuf payload. 
Use models_pb2.WebsocketReconnectStrategy enum.""" if self.payload is None: return None - return getattr(self.payload, 'strategy', None) + return getattr(self.payload, "strategy", None) @property def announced_tracks(self) -> Optional[List[TrackInfo]]: """Access announced_tracks field from the protobuf payload.""" if self.payload is None: return None - proto_list = getattr(self.payload, 'announced_tracks', []) - return [TrackInfo.from_proto(item) for item in proto_list] if proto_list else None + proto_list = getattr(self.payload, "announced_tracks", []) + return ( + [TrackInfo.from_proto(item) for item in proto_list] if proto_list else None + ) @property def subscriptions(self) -> Optional[List[Any]]: """Access subscriptions field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'subscriptions', None) + return getattr(self.payload, "subscriptions", None) @property def reconnect_attempt(self) -> Optional[int]: """Access reconnect_attempt field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'reconnect_attempt', None) + return getattr(self.payload, "reconnect_attempt", None) @property def from_sfu_id(self) -> Optional[str]: """Access from_sfu_id field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'from_sfu_id', None) + return getattr(self.payload, "from_sfu_id", None) @property def previous_session_id(self) -> Optional[str]: """Access previous_session_id field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'previous_session_id', None) + return getattr(self.payload, "previous_session_id", None) @property def reason(self) -> Optional[str]: """Access reason field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'reason', None) + return getattr(self.payload, "reason", None) @classmethod def from_proto(cls, proto_obj: events_pb2.ReconnectDetails, **extra): @@ -1574,11 +1777,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class SfuEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.SfuEvent.""" + type: str = field(default="stream.video.sfu.event.SfuEvent", init=False) payload: Optional[events_pb2.SfuEvent] = field(default=None, repr=False) @@ -1587,35 +1794,35 @@ def subscriber_offer(self) -> Optional[Any]: """Access subscriber_offer field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'subscriber_offer', None) + return getattr(self.payload, "subscriber_offer", None) @property def publisher_answer(self) -> Optional[Any]: """Access publisher_answer field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'publisher_answer', None) + return getattr(self.payload, "publisher_answer", None) @property def connection_quality_changed(self) -> Optional[Any]: """Access connection_quality_changed field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'connection_quality_changed', None) + return getattr(self.payload, "connection_quality_changed", None) @property def audio_level_changed(self) -> 
Optional[Any]: """Access audio_level_changed field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'audio_level_changed', None) + return getattr(self.payload, "audio_level_changed", None) @property def ice_trickle(self) -> Optional[ICETrickle]: """Access ice_trickle field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'ice_trickle', None) + proto_val = getattr(self.payload, "ice_trickle", None) return ICETrickle.from_proto(proto_val) if proto_val is not None else None @property @@ -1623,126 +1830,126 @@ def change_publish_quality(self) -> Optional[Any]: """Access change_publish_quality field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'change_publish_quality', None) + return getattr(self.payload, "change_publish_quality", None) @property def participant_joined(self) -> Optional[Any]: """Access participant_joined field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'participant_joined', None) + return getattr(self.payload, "participant_joined", None) @property def participant_left(self) -> Optional[Any]: """Access participant_left field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'participant_left', None) + return getattr(self.payload, "participant_left", None) @property def dominant_speaker_changed(self) -> Optional[Any]: """Access dominant_speaker_changed field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'dominant_speaker_changed', None) + return getattr(self.payload, "dominant_speaker_changed", None) @property def join_response(self) -> Optional[Any]: """Access join_response field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'join_response', None) + return getattr(self.payload, "join_response", None) @property def health_check_response(self) -> Optional[Any]: """Access health_check_response field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'health_check_response', None) + return getattr(self.payload, "health_check_response", None) @property def track_published(self) -> Optional[Any]: """Access track_published field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'track_published', None) + return getattr(self.payload, "track_published", None) @property def track_unpublished(self) -> Optional[Any]: """Access track_unpublished field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'track_unpublished', None) + return getattr(self.payload, "track_unpublished", None) @property def error(self) -> Optional[Any]: """Access error field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'error', None) + return getattr(self.payload, "error", None) @property def call_grants_updated(self) -> Optional[Any]: """Access call_grants_updated field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'call_grants_updated', None) + return getattr(self.payload, "call_grants_updated", None) @property def go_away(self) -> Optional[Any]: """Access go_away field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'go_away', None) + return getattr(self.payload, 
"go_away", None) @property def ice_restart(self) -> Optional[Any]: """Access ice_restart field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'ice_restart', None) + return getattr(self.payload, "ice_restart", None) @property def pins_updated(self) -> Optional[Any]: """Access pins_updated field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'pins_updated', None) + return getattr(self.payload, "pins_updated", None) @property def call_ended(self) -> Optional[Any]: """Access call_ended field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'call_ended', None) + return getattr(self.payload, "call_ended", None) @property def participant_updated(self) -> Optional[Any]: """Access participant_updated field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'participant_updated', None) + return getattr(self.payload, "participant_updated", None) @property def participant_migration_complete(self) -> Optional[Any]: """Access participant_migration_complete field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'participant_migration_complete', None) + return getattr(self.payload, "participant_migration_complete", None) @property def change_publish_options(self) -> Optional[Any]: """Access change_publish_options field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'change_publish_options', None) + return getattr(self.payload, "change_publish_options", None) @property def inbound_state_notification(self) -> Optional[Any]: """Access inbound_state_notification field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'inbound_state_notification', None) + return getattr(self.payload, "inbound_state_notification", None) @classmethod def from_proto(cls, proto_obj: events_pb2.SfuEvent, **extra): @@ -1759,11 +1966,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class SfuRequestEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.SfuRequest.""" + type: str = field(default="stream.video.sfu.event.SfuRequest", init=False) payload: Optional[events_pb2.SfuRequest] = field(default=None, repr=False) @@ -1772,21 +1983,21 @@ def join_request(self) -> Optional[Any]: """Access join_request field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'join_request', None) + return getattr(self.payload, "join_request", None) @property def health_check_request(self) -> Optional[Any]: """Access health_check_request field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'health_check_request', None) + return getattr(self.payload, "health_check_request", None) @property def leave_call_request(self) -> Optional[Any]: """Access leave_call_request field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'leave_call_request', None) + return getattr(self.payload, "leave_call_request", None) @classmethod def from_proto(cls, proto_obj: 
events_pb2.SfuRequest, **extra): @@ -1803,11 +2014,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class SubscriberOfferEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.SubscriberOffer.""" + type: str = field(default="stream.video.sfu.event.SubscriberOffer", init=False) payload: Optional[events_pb2.SubscriberOffer] = field(default=None, repr=False) @@ -1816,14 +2031,14 @@ def ice_restart(self) -> Optional[bool]: """Access ice_restart field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'ice_restart', None) + return getattr(self.payload, "ice_restart", None) @property def sdp(self) -> Optional[str]: """Access sdp field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'sdp', None) + return getattr(self.payload, "sdp", None) @classmethod def from_proto(cls, proto_obj: events_pb2.SubscriberOffer, **extra): @@ -1840,11 +2055,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class TrackPublishedEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.TrackPublished.""" + type: str = field(default="stream.video.sfu.event.TrackPublished", init=False) payload: Optional[events_pb2.TrackPublished] = field(default=None, repr=False) @@ -1853,14 +2072,14 @@ def user_id(self) -> Optional[str]: """Access user_id field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'user_id', None) + return getattr(self.payload, "user_id", None) @property def participant(self) -> Optional[Participant]: """Access participant field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'participant', None) + proto_val = getattr(self.payload, "participant", None) return Participant.from_proto(proto_val) if proto_val is not None else None @classmethod @@ -1878,11 +2097,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class TrackUnpublishedEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.TrackUnpublished.""" + type: str = field(default="stream.video.sfu.event.TrackUnpublished", init=False) payload: Optional[events_pb2.TrackUnpublished] = field(default=None, repr=False) @@ -1891,21 +2114,21 @@ def user_id(self) -> Optional[str]: """Access user_id field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'user_id', None) + return getattr(self.payload, "user_id", None) @property def cause(self) -> Optional[int]: """Access cause field from the protobuf payload. 
Use models_pb2.TrackUnpublishReason enum.""" if self.payload is None: return None - return getattr(self.payload, 'cause', None) + return getattr(self.payload, "cause", None) @property def participant(self) -> Optional[Participant]: """Access participant field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'participant', None) + proto_val = getattr(self.payload, "participant", None) return Participant.from_proto(proto_val) if proto_val is not None else None @classmethod @@ -1923,11 +2146,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class VideoLayerSettingEvent(BaseEvent): """Dataclass event for video.sfu.event.events_pb2.VideoLayerSetting.""" + type: str = field(default="stream.video.sfu.event.VideoLayerSetting", init=False) payload: Optional[events_pb2.VideoLayerSetting] = field(default=None, repr=False) @@ -1936,35 +2163,35 @@ def name(self) -> Optional[str]: """Access name field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'name', None) + return getattr(self.payload, "name", None) @property def active(self) -> Optional[bool]: """Access active field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'active', None) + return getattr(self.payload, "active", None) @property def max_bitrate(self) -> Optional[int]: """Access max_bitrate field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'max_bitrate', None) + return getattr(self.payload, "max_bitrate", None) @property def scale_resolution_down_by(self) -> Optional[float]: """Access scale_resolution_down_by field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'scale_resolution_down_by', None) + return getattr(self.payload, "scale_resolution_down_by", None) @property def codec(self) -> Optional[Codec]: """Access codec field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'codec', None) + proto_val = getattr(self.payload, "codec", None) return Codec.from_proto(proto_val) if proto_val is not None else None @property @@ -1972,14 +2199,14 @@ def max_framerate(self) -> Optional[int]: """Access max_framerate field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'max_framerate', None) + return getattr(self.payload, "max_framerate", None) @property def scalability_mode(self) -> Optional[str]: """Access scalability_mode field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'scalability_mode', None) + return getattr(self.payload, "scalability_mode", None) @classmethod def from_proto(cls, proto_obj: events_pb2.VideoLayerSetting, **extra): @@ -1996,11 +2223,15 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) + @dataclass class VideoSenderEvent(BaseEvent): """Dataclass 
event for video.sfu.event.events_pb2.VideoSender.""" + type: str = field(default="stream.video.sfu.event.VideoSender", init=False) payload: Optional[events_pb2.VideoSender] = field(default=None, repr=False) @@ -2009,7 +2240,7 @@ def codec(self) -> Optional[Codec]: """Access codec field from the protobuf payload.""" if self.payload is None: return None - proto_val = getattr(self.payload, 'codec', None) + proto_val = getattr(self.payload, "codec", None) return Codec.from_proto(proto_val) if proto_val is not None else None @property @@ -2017,21 +2248,21 @@ def layers(self) -> Optional[List[Any]]: """Access layers field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'layers', None) + return getattr(self.payload, "layers", None) @property def track_type(self) -> Optional[int]: """Access track_type field from the protobuf payload. Use models_pb2.TrackType enum.""" if self.payload is None: return None - return getattr(self.payload, 'track_type', None) + return getattr(self.payload, "track_type", None) @property def publish_option_id(self) -> Optional[int]: """Access publish_option_id field from the protobuf payload.""" if self.payload is None: return None - return getattr(self.payload, 'publish_option_id', None) + return getattr(self.payload, "publish_option_id", None) @classmethod def from_proto(cls, proto_obj: events_pb2.VideoSender, **extra): @@ -2048,7 +2279,9 @@ def __getattr__(self, item: str): """Delegate attribute access to protobuf payload.""" if self.payload is not None: return getattr(self.payload, item) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{item}'" + ) __all__ = ( diff --git a/agents-core/vision_agents/core/edge/types.py b/agents-core/vision_agents/core/edge/types.py index 5f68b1a3..fabe7f27 100644 --- a/agents-core/vision_agents/core/edge/types.py +++ b/agents-core/vision_agents/core/edge/types.py @@ -1,4 +1,4 @@ -#from __future__ import annotations +# from __future__ import annotations from dataclasses import dataclass from typing import Any, Optional, NamedTuple import logging @@ -30,6 +30,7 @@ class Connection(AsyncIOEventEmitter): and a way to receive a callback when the call is ended In the future we might want to forward more events """ + async def close(self): pass @@ -107,19 +108,16 @@ def dts_seconds(self) -> Optional[float]: @classmethod def from_bytes( - cls, - audio_bytes: bytes, - sample_rate: int = 16000, - format: str = "s16" + cls, audio_bytes: bytes, sample_rate: int = 16000, format: str = "s16" ) -> "PcmData": """ Create PcmData from raw audio bytes. - + Args: audio_bytes: Raw audio data as bytes sample_rate: Sample rate in Hz format: Audio format (e.g., "s16", "f32") - + Returns: PcmData object """ @@ -129,54 +127,52 @@ def from_bytes( def resample(self, target_sample_rate: int) -> "PcmData": """ Resample PcmData to a different sample rate using AV library. 
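For orientation, `from_bytes` and `resample` compose as follows. A minimal usage sketch, assuming the `vision_agents.core.edge.types` import path from this diff's file header; the 100 ms buffer of silence is hypothetical:

```python
import numpy as np

from vision_agents.core.edge.types import PcmData

# Hypothetical 100 ms of silent 16 kHz mono s16 audio
raw = np.zeros(1600, dtype=np.int16).tobytes()

pcm = PcmData.from_bytes(raw, sample_rate=16000, format="s16")
pcm48 = pcm.resample(48000)  # AV-backed resampling, implemented below
assert pcm48.sample_rate == 48000
```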
- + Args: target_sample_rate: Target sample rate in Hz - + Returns: New PcmData object with resampled audio """ if self.sample_rate == target_sample_rate: return self - + # Ensure samples are 2D for AV library (channels, samples) samples = self.samples if samples.ndim == 1: # Reshape 1D array to 2D (1 channel, samples) samples = samples.reshape(1, -1) - + # Create AV audio frame from the samples - frame = av.AudioFrame.from_ndarray(samples, format='s16', layout='mono') + frame = av.AudioFrame.from_ndarray(samples, format="s16", layout="mono") frame.sample_rate = self.sample_rate - + # Create resampler resampler = av.AudioResampler( - format='s16', - layout='mono', - rate=target_sample_rate + format="s16", layout="mono", rate=target_sample_rate ) - + # Resample the frame resampled_frames = resampler.resample(frame) if resampled_frames: resampled_frame = resampled_frames[0] resampled_samples = resampled_frame.to_ndarray() - + # AV returns (channels, samples), so for mono we want the first (and only) channel if len(resampled_samples.shape) > 1: # Take the first channel (mono) resampled_samples = resampled_samples[0] - + # Convert to int16 resampled_samples = resampled_samples.astype(np.int16) - + return PcmData( samples=resampled_samples, sample_rate=target_sample_rate, format=self.format, pts=self.pts, dts=self.dts, - time_base=self.time_base + time_base=self.time_base, ) else: # If resampling failed, return original data diff --git a/agents-core/vision_agents/core/events/__init__.py b/agents-core/vision_agents/core/events/__init__.py index 66441fef..3776f914 100644 --- a/agents-core/vision_agents/core/events/__init__.py +++ b/agents-core/vision_agents/core/events/__init__.py @@ -127,5 +127,5 @@ "PluginInitializedEvent", "PluginClosedEvent", "PluginErrorEvent", - "EventManager" + "EventManager", ] diff --git a/agents-core/vision_agents/core/events/base.py b/agents-core/vision_agents/core/events/base.py index 558876c2..020c50ba 100644 --- a/agents-core/vision_agents/core/events/base.py +++ b/agents-core/vision_agents/core/events/base.py @@ -9,6 +9,7 @@ from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant + class ConnectionState(Enum): """Connection states for streaming plugins.""" @@ -32,6 +33,7 @@ class AudioFormat(Enum): @dataclass class BaseEvent(DataClassJsonMixin): """Base class for all events.""" + type: str event_id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -83,11 +85,12 @@ class PluginErrorEvent(PluginBaseEvent): def error_message(self) -> str: return str(self.error) if self.error else "Unknown error" + @dataclasses.dataclass class ExceptionEvent: exc: Exception handler: FunctionType - type: str = 'base.exception' + type: str = "base.exception" @dataclasses.dataclass @@ -95,13 +98,13 @@ class HealthCheckEvent(DataClassJsonMixin): connection_id: str created_at: int custom: dict - type: str = 'health.check' + type: str = "health.check" @dataclass class ConnectionOkEvent(BaseEvent): """Event emitted when WebSocket connection is established.""" - + type: str = field(default="connection.ok", init=False) connection_id: Optional[str] = None server_time: Optional[str] = None @@ -112,7 +115,7 @@ class ConnectionOkEvent(BaseEvent): @dataclass class ConnectionErrorEvent(BaseEvent): """Event emitted when WebSocket connection encounters an error.""" - + type: str = field(default="connection.error", init=False) error_code: Optional[str] = None error_message: Optional[str] = 
None @@ -122,7 +125,7 @@ class ConnectionErrorEvent(BaseEvent): @dataclass class ConnectionClosedEvent(BaseEvent): """Event emitted when WebSocket connection is closed.""" - + type: str = field(default="connection.closed", init=False) code: Optional[int] = None reason: Optional[str] = None diff --git a/agents-core/vision_agents/core/events/manager.py b/agents-core/vision_agents/core/events/manager.py index 8e9b3669..65389b81 100644 --- a/agents-core/vision_agents/core/events/manager.py +++ b/agents-core/vision_agents/core/events/manager.py @@ -21,47 +21,46 @@ def _truncate_event_for_logging(event, max_length=200): """ Truncate event data for logging to prevent log spam. - + Args: event: The event object to truncate max_length: Maximum length of the string representation - + Returns: Truncated string representation of the event """ event_str = str(event) - + # Special handling for audio data arrays - if hasattr(event, 'pcm_data') and hasattr(event.pcm_data, 'samples'): + if hasattr(event, "pcm_data") and hasattr(event.pcm_data, "samples"): # Replace the full array with a summary samples = event.pcm_data.samples array_summary = f"array([{samples[0]}, {samples[1]}, ..., {samples[-1]}], dtype={samples.dtype}, size={len(samples)})" event_str = event_str.replace(str(samples), array_summary) - + # If the event is still too long, truncate it if len(event_str) > max_length: # Find a good truncation point (end of a field) truncate_at = max_length - 20 # Leave room for "... (truncated)" - while truncate_at > 0 and event_str[truncate_at] not in [',', ')', '}']: + while truncate_at > 0 and event_str[truncate_at] not in [",", ")", "}"]: truncate_at -= 1 - + if truncate_at > 0: event_str = event_str[:truncate_at] + "... (truncated)" else: - event_str = event_str[:max_length-20] + "... (truncated)" - - return event_str + event_str = event_str[: max_length - 20] + "... (truncated)" + return event_str class EventManager: """ A comprehensive event management system for handling asynchronous event-driven communication. - + The EventManager provides a centralized way to register events, subscribe handlers, and process events asynchronously. It supports event queuing, error handling, and automatic exception event generation. 
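To make the `_truncate_event_for_logging` helper above concrete, a small sketch; the `Fake*` classes are hypothetical, and importing the private helper directly is an assumption for illustration only:

```python
from dataclasses import dataclass

import numpy as np

from vision_agents.core.events.manager import _truncate_event_for_logging

@dataclass
class FakePcm:
    samples: np.ndarray

@dataclass
class FakeEvent:
    pcm_data: FakePcm

event = FakeEvent(FakePcm(np.arange(48000, dtype=np.int16)))
# Large payloads are summarized or cut at max_length with a "... (truncated)" suffix
print(_truncate_event_for_logging(event, max_length=200))
```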
- + Features: - Event registration and validation - Handler subscription with type hints @@ -69,42 +68,42 @@ class EventManager: - Error handling with automatic exception events - Support for Union types in handlers - Event queuing and batch processing - + Example: ```python from vision_agents.core.events.manager import EventManager from vision_agents.core.vad.events import VADSpeechStartEvent, VADSpeechEndEvent from vision_agents.core.stt.events import STTTranscriptEvent from vision_agents.core.tts.events import TTSAudioEvent - + # Create event manager manager = EventManager() - + # Register events manager.register(VADSpeechStartEvent) manager.register(VADSpeechEndEvent) manager.register(STTTranscriptEvent) manager.register(TTSAudioEvent) - + # Subscribe to VAD events @manager.subscribe async def handle_speech_start(event: VADSpeechStartEvent): print(f"Speech started with probability {event.speech_probability}") - + @manager.subscribe async def handle_speech_end(event: VADSpeechEndEvent): print(f"Speech ended after {event.total_speech_duration_ms}ms") - + # Subscribe to STT events @manager.subscribe async def handle_transcript(event: STTTranscriptEvent): print(f"Transcript: {event.text} (confidence: {event.confidence})") - + # Subscribe to multiple event types using Union @manager.subscribe async def handle_audio_events(event: VADSpeechStartEvent | VADSpeechEndEvent): print(f"VAD event: {event.type}") - + # Send events manager.send(VADSpeechStartEvent( plugin_name="silero", @@ -116,20 +115,20 @@ async def handle_audio_events(event: VADSpeechStartEvent | VADSpeechEndEvent): text="Hello world", confidence=0.98 )) - + # Before shutdown, ensure all events are processed await manager.shutdown() ``` - + Args: ignore_unknown_events (bool): If True, unknown events are ignored rather than raising errors. Defaults to True. """ - + def __init__(self, ignore_unknown_events: bool = True): """ Initialize the EventManager. - + Args: ignore_unknown_events (bool): If True, unknown events are ignored rather than raising errors. Defaults to True. @@ -149,51 +148,55 @@ def __init__(self, ignore_unknown_events: bool = True): self.register(ConnectionOkEvent) self.register(ConnectionErrorEvent) self.register(ConnectionClosedEvent) - + # Start background processing task self._start_processing_task() def register(self, event_class, ignore_not_compatible=False): """ Register an event class for use with the event manager. - + Event classes must: - Have a name ending with 'Event' - Have a 'type' attribute (string) - + Example: ```python from vision_agents.core.vad.events import VADSpeechStartEvent from vision_agents.core.stt.events import STTTranscriptEvent - + manager = EventManager() manager.register(VADSpeechStartEvent) manager.register(STTTranscriptEvent) ``` - + Args: event_class: The event class to register ignore_not_compatible (bool): If True, log warning instead of raising error for incompatible classes. Defaults to False. 
- + Raises: ValueError: If event_class doesn't meet requirements and ignore_not_compatible is False """ - if event_class.__name__.endswith('Event') and hasattr(event_class, 'type'): + if event_class.__name__.endswith("Event") and hasattr(event_class, "type"): self._events[event_class.type] = event_class logger.debug(f"Registered new event {event_class} - {event_class.type}") - elif event_class.__name__.endswith('BaseEvent'): + elif event_class.__name__.endswith("BaseEvent"): return elif not ignore_not_compatible: - raise ValueError(f"Provide valid class that ends on '*Event' and 'type' attribute: {event_class}") + raise ValueError( + f"Provide a valid class whose name ends with 'Event' and has a 'type' attribute: {event_class}" + ) else: - logger.warning(f"Provide valid class that ends on '*Event' and 'type' attribute: {event_class}") + logger.warning( + f"Provide a valid class whose name ends with 'Event' and has a 'type' attribute: {event_class}" + ) - def merge(self, em: 'EventManager'): + def merge(self, em: "EventManager"): # Stop the processing task in the merged manager if em._processing_task and not em._processing_task.done(): em._processing_task.cancel() - + # Merge all data from the other manager self._events.update(em._events) self._modules.update(em._modules) @@ -211,29 +214,31 @@ em._silent_events = self._silent_events em._processing_task = None # Clear the stopped task reference - def register_events_from_module(self, module, prefix='', ignore_not_compatible=True): + def register_events_from_module( + self, module, prefix="", ignore_not_compatible=True + ): """ Register all event classes from a module. - + Automatically discovers and registers all classes in a module that: - Have names ending with 'Event' - Have a 'type' attribute (optionally matching the prefix) - + Example: ```python # Register all VAD events from the core module from vision_agents.core import vad manager.register_events_from_module(vad.events, prefix="plugin.vad") - + # Register all TTS events from the core module from vision_agents.core import tts manager.register_events_from_module(tts.events, prefix="plugin.tts") - + # Register all events from a plugin module from vision_agents.plugins.silero import events as silero_events manager.register_events_from_module(silero_events, prefix="plugin.silero") ``` - + Args: module: The Python module to scan for event classes prefix (str): Optional prefix to filter event types. Only events with @@ -242,7 +247,9 @@ for incompatible classes. Defaults to True. """ for name, class_ in module.__dict__.items(): - if name.endswith('Event') and (not prefix or getattr(class_, 'type', '').startswith(prefix)): + if name.endswith("Event") and ( + not prefix or getattr(class_, "type", "").startswith(prefix) + ): self.register(class_, ignore_not_compatible=ignore_not_compatible) self._modules.setdefault(module.__name__, []).append(class_) @@ -265,20 +272,20 @@ def _generate_import_file(self): def unsubscribe(self, function): """ Unsubscribe a function from all event types. - + Removes the specified function from all event handler lists. This is useful for cleaning up handlers that are no longer needed. - + Example: ```python @manager.subscribe async def speech_handler(event: VADSpeechStartEvent): print("Speech started") - + # Later, unsubscribe the handler manager.unsubscribe(speech_handler) ``` - + Args: function: The function to unsubscribe from all event types.
""" @@ -292,30 +299,30 @@ async def speech_handler(event: VADSpeechStartEvent): def subscribe(self, function): """ Subscribe a function to handle specific event types. - + The function must have type hints indicating which event types it handles. Supports both single event types and Union types for handling multiple event types. - + Example: ```python # Single event type @manager.subscribe async def handle_speech_start(event: VADSpeechStartEvent): print(f"Speech started with probability {event.speech_probability}") - + # Multiple event types using Union @manager.subscribe async def handle_audio_events(event: VADSpeechStartEvent | VADSpeechEndEvent): print(f"VAD event: {event.type}") ``` - + Args: function: The async function to subscribe as an event handler. Must have type hints for event parameters. - + Returns: The decorated function (for use as decorator). - + Raises: RuntimeError: If handler has multiple separate event parameters (use Union instead) KeyError: If event type is not registered and ignore_unknown_events is False @@ -338,44 +345,57 @@ async def handle_audio_events(event: VADSpeechStartEvent | VADSpeechEndEvent): event_type = getattr(sub_event, "type", None) if subscribed and not is_union: - raise RuntimeError("Multiple seperated events per handler are not supported, use Union instead") + raise RuntimeError( + "Multiple seperated events per handler are not supported, use Union instead" + ) if event_type in self._events: subscribed = True self._handlers.setdefault(event_type, []).append(function) - module_name = getattr(function, '__module__', 'unknown') - logger.info(f"Handler {function.__name__} from {module_name} registered for event {event_type}") + module_name = getattr(function, "__module__", "unknown") + logger.info( + f"Handler {function.__name__} from {module_name} registered for event {event_type}" + ) elif not self._ignore_unknown_events: - raise KeyError(f"Event {sub_event} - {event_type} is not registered.") + raise KeyError( + f"Event {sub_event} - {event_type} is not registered." + ) else: - module_name = getattr(function, '__module__', 'unknown') - logger.info(f"Event {sub_event} - {event_type} is not registered – skipping handler {function.__name__} from {module_name}.") + module_name = getattr(function, "__module__", "unknown") + logger.info( + f"Event {sub_event} - {event_type} is not registered – skipping handler {function.__name__} from {module_name}." 
+ ) return function def _prepare_event(self, event): # Handle dict events - convert to event class if isinstance(event, dict): - event_type = event.get('type', '') + event_type = event.get("type", "") try: event_class = self._events[event_type] event = event_class.from_dict(event, infer_missing=True) # type: ignore[attr-defined] except Exception: logger.exception(f"Can't convert dict {event} to event class, skipping") return - + # Handle raw protobuf messages - wrap in BaseEvent subclass # Check for protobuf DESCRIPTOR but exclude already-wrapped BaseEvent subclasses - elif (hasattr(event, 'DESCRIPTOR') and hasattr(event.DESCRIPTOR, 'full_name') - and not hasattr(event, 'event_id')): # event_id is unique to BaseEvent + elif ( + hasattr(event, "DESCRIPTOR") + and hasattr(event.DESCRIPTOR, "full_name") + and not hasattr(event, "event_id") + ): # event_id is unique to BaseEvent proto_type = event.DESCRIPTOR.full_name - + # Look up the registered event class by protobuf type proto_event_class = self._events.get(proto_type) - if proto_event_class and hasattr(proto_event_class, 'from_proto'): + if proto_event_class and hasattr(proto_event_class, "from_proto"): try: event = proto_event_class.from_proto(event) except Exception: - logger.exception(f"Failed to convert protobuf {proto_type} to event class {proto_event_class}") + logger.exception( + f"Failed to convert protobuf {proto_type} to event class {proto_event_class}" + ) return else: # No matching event class found @@ -384,10 +404,10 @@ def _prepare_event(self, event): return else: raise RuntimeError(f"Protobuf event not registered: {proto_type}") - + # Validate event is registered (handles both BaseEvent and generated protobuf events) - if hasattr(event, 'type') and event.type in self._events: - #logger.info(f"Received event {_truncate_event_for_logging(event)}") + if hasattr(event, "type") and event.type in self._events: + # logger.info(f"Received event {_truncate_event_for_logging(event)}") return event elif self._ignore_unknown_events: logger.info(f"Event not registered {_truncate_event_for_logging(event)}") @@ -397,7 +417,7 @@ def _prepare_event(self, event): def silent(self, event_class): """ Silence logging for an event class from being processed. - + Args: event_class: The event class to silence """ @@ -406,11 +426,11 @@ def silent(self, event_class): def send(self, *events): """ Send one or more events for processing. - + Events are added to the queue and will be processed by the background - processing task. If an event handler raises an exception, an ExceptionEvent + processing task. If an event handler raises an exception, an ExceptionEvent is automatically created and queued for processing. - + Example: ```python # Send single event @@ -419,13 +439,13 @@ def send(self, *events): speech_probability=0.95, activation_threshold=0.5 )) - + # Send multiple events manager.send( VADSpeechStartEvent(plugin_name="silero", speech_probability=0.95), STTTranscriptEvent(plugin_name="deepgram", text="Hello world") ) - + # Send event from dictionary manager.send({ "type": "plugin.vad_speech_start", @@ -433,27 +453,27 @@ def send(self, *events): "speech_probability": 0.95 }) ``` - + Args: *events: One or more event objects or dictionaries to send. Events can be instances of registered event classes or dictionaries with a 'type' field that matches a registered event type. 
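The `_prepare_event` branches above are duck-typed rather than isinstance-based; the same routing in a standalone sketch, with hypothetical stand-in classes rather than the real generated protobuf types:

```python
class FakeDescriptor:
    full_name = "stream.video.sfu.event.VideoSender"

class FakeProto:
    DESCRIPTOR = FakeDescriptor()  # looks like a protobuf message...
    # ...and has no event_id attribute, so it is not an already-wrapped BaseEvent

def route(event) -> str:
    if isinstance(event, dict):
        return "dict: look up event['type'], build via from_dict()"
    if (
        hasattr(event, "DESCRIPTOR")
        and hasattr(event.DESCRIPTOR, "full_name")
        and not hasattr(event, "event_id")
    ):
        return "raw protobuf: look up DESCRIPTOR.full_name, wrap via from_proto()"
    return "assumed BaseEvent subclass: validate event.type against the registry"

print(route({"type": "plugin.vad_speech_start"}))  # dict path
print(route(FakeProto()))                          # protobuf path
```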
- + Raises: RuntimeError: If event type is not registered and ignore_unknown_events is False """ for event in events: event = self._prepare_event(event) if event: - #logger.info(f"🎯 EventManager.send: {event.__class__.__name__} - {event.type}") + # logger.info(f"🎯 EventManager.send: {event.__class__.__name__} - {event.type}") self._queue.append(event) - + async def wait(self, timeout: float = 10.0): """ Wait for all queued events to be processed. - + This is useful in tests to ensure events are processed before assertions. - + Args: timeout: Maximum time to wait for processing to complete """ @@ -463,19 +483,19 @@ async def wait(self, timeout: float = 10.0): if self._handler_tasks: await asyncio.wait(list(self._handler_tasks.values())) - + def _start_processing_task(self): """Start the background event processing task.""" if self._processing_task and not self._processing_task.done(): return # Already running - + loop = asyncio.get_running_loop() self._processing_task = loop.create_task(self._process_events_loop()) async def _process_events_loop(self): """ Background task that continuously processes events from the queue. - + This task runs until shutdown is requested and processes all events in the queue. It's shielded from cancellation to ensure all events are processed before shutdown. @@ -488,12 +508,18 @@ async def _process_events_loop(self): await self._process_single_event(event) except asyncio.CancelledError as exc: cancelled_exc = exc - logger.info(f"Event processing task was cancelled, processing remaining events, {len(self._queue)}") + logger.info( + f"Event processing task was cancelled, processing remaining events, {len(self._queue)}" + ) await self._process_single_event(event) elif cancelled_exc: raise cancelled_exc else: - cleanup_ids = set(task_id for task_id, task in self._handler_tasks.items() if task.done()) + cleanup_ids = set( + task_id + for task_id, task in self._handler_tasks.items() + if task.done() + ) for task_id in cleanup_ids: self._handler_tasks.pop(task_id) await asyncio.sleep(0.0001) @@ -503,17 +529,20 @@ async def _run_handler(self, handler, event): return await handler(event) except Exception as exc: self._queue.appendleft(ExceptionEvent(exc, handler)) # type: ignore[arg-type] - module_name = getattr(handler, '__module__', 'unknown') - logger.exception(f"Error calling handler {handler.__name__} from {module_name} for event {event.type}") + module_name = getattr(handler, "__module__", "unknown") + logger.exception( + f"Error calling handler {handler.__name__} from {module_name} for event {event.type}" + ) async def _process_single_event(self, event): """Process a single event.""" for handler in self._handlers.get(event.type, []): - module_name = getattr(handler, '__module__', 'unknown') + module_name = getattr(handler, "__module__", "unknown") if event.type not in self._silent_events: - logger.info(f"Called handler {handler.__name__} from {module_name} for event {event.type}") + logger.info( + f"Called handler {handler.__name__} from {module_name} for event {event.type}" + ) loop = asyncio.get_running_loop() handler_task = loop.create_task(self._run_handler(handler, event)) self._handler_tasks[uuid.uuid4()] = handler_task - diff --git a/agents-core/vision_agents/core/llm/events.py b/agents-core/vision_agents/core/llm/events.py index e517a479..5ddbb334 100644 --- a/agents-core/vision_agents/core/llm/events.py +++ b/agents-core/vision_agents/core/llm/events.py @@ -7,7 +7,8 @@ @dataclass class RealtimeConnectedEvent(PluginBaseEvent): """Event emitted when 
realtime connection is established.""" - type: str = field(default='plugin.realtime_connected', init=False) + + type: str = field(default="plugin.realtime_connected", init=False) provider: Optional[str] = None session_config: Optional[dict[str, Any]] = None capabilities: Optional[list[str]] = None @@ -15,7 +16,7 @@ class RealtimeConnectedEvent(PluginBaseEvent): @dataclass class RealtimeDisconnectedEvent(PluginBaseEvent): - type: str = field(default='plugin.realtime_disconnected', init=False) + type: str = field(default="plugin.realtime_disconnected", init=False) provider: Optional[str] = None reason: Optional[str] = None was_clean: bool = True @@ -24,7 +25,8 @@ class RealtimeDisconnectedEvent(PluginBaseEvent): @dataclass class RealtimeAudioInputEvent(PluginBaseEvent): """Event emitted when audio input is sent to realtime session.""" - type: str = field(default='plugin.realtime_audio_input', init=False) + + type: str = field(default="plugin.realtime_audio_input", init=False) audio_data: Optional[bytes] = None audio_format: AudioFormat = AudioFormat.PCM_S16 sample_rate: int = 16000 @@ -34,7 +36,8 @@ class RealtimeAudioInputEvent(PluginBaseEvent): @dataclass class RealtimeAudioOutputEvent(PluginBaseEvent): """Event emitted when audio output is received from realtime session.""" - type: str = field(default='plugin.realtime_audio_output', init=False) + + type: str = field(default="plugin.realtime_audio_output", init=False) audio_data: Optional[bytes] = None audio_format: AudioFormat = AudioFormat.PCM_S16 sample_rate: int = 16000 @@ -45,8 +48,9 @@ class RealtimeAudioOutputEvent(PluginBaseEvent): @dataclass class RealtimeTranscriptEvent(PluginBaseEvent): """Event emitted when realtime session provides a transcript.""" + original: Optional[Any] = None - type: str = field(default='plugin.realtime_transcript', init=False) + type: str = field(default="plugin.realtime_transcript", init=False) text: Optional[str] = None user_metadata: Optional[Any] = None @@ -54,7 +58,7 @@ class RealtimeTranscriptEvent(PluginBaseEvent): @dataclass class RealtimePartialTranscriptEvent(PluginBaseEvent): original: Optional[Any] = None - type: str = field(default='plugin.realtime_partial_transcript', init=False) + type: str = field(default="plugin.realtime_partial_transcript", init=False) text: Optional[str] = None user_metadata: Optional[Any] = None @@ -62,7 +66,8 @@ class RealtimePartialTranscriptEvent(PluginBaseEvent): @dataclass class RealtimeResponseEvent(PluginBaseEvent): """Event emitted when realtime session provides a response.""" - type: str = field(default='plugin.realtime_response', init=False) + + type: str = field(default="plugin.realtime_response", init=False) original: Optional[str] = None text: Optional[str] = None response_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -73,7 +78,8 @@ class RealtimeResponseEvent(PluginBaseEvent): @dataclass class RealtimeConversationItemEvent(PluginBaseEvent): """Event emitted for conversation item updates in realtime session.""" - type: str = field(default='plugin.realtime_conversation_item', init=False) + + type: str = field(default="plugin.realtime_conversation_item", init=False) item_id: Optional[str] = None item_type: Optional[str] = ( None # "message", "function_call", "function_call_output" @@ -86,7 +92,8 @@ class RealtimeConversationItemEvent(PluginBaseEvent): @dataclass class RealtimeErrorEvent(PluginBaseEvent): """Event emitted when a realtime error occurs.""" - type: str = field(default='plugin.realtime_error', init=False) + + type: str = 
field(default="plugin.realtime_error", init=False) error: Optional[Exception] = None error_code: Optional[str] = None context: Optional[str] = None @@ -96,9 +103,10 @@ class RealtimeErrorEvent(PluginBaseEvent): def error_message(self) -> str: return str(self.error) if self.error else "Unknown error" + @dataclass class LLMResponseChunkEvent(PluginBaseEvent): - type: str = field(default='plugin.llm_response_chunk', init=False) + type: str = field(default="plugin.llm_response_chunk", init=False) content_index: int | None = None """The index of the content part that the text delta was added to.""" @@ -118,7 +126,8 @@ class LLMResponseChunkEvent(PluginBaseEvent): @dataclass class LLMResponseCompletedEvent(PluginBaseEvent): """Event emitted after an LLM response is processed.""" - type: str = field(default='plugin.llm_response_completed', init=False) + + type: str = field(default="plugin.llm_response_completed", init=False) original: Any = None text: str = "" @@ -126,7 +135,8 @@ class LLMResponseCompletedEvent(PluginBaseEvent): @dataclass class ToolStartEvent(PluginBaseEvent): """Event emitted when a tool execution starts.""" - type: str = field(default='plugin.llm.tool.start', init=False) + + type: str = field(default="plugin.llm.tool.start", init=False) tool_name: str = "" arguments: Optional[Dict[str, Any]] = None tool_call_id: Optional[str] = None @@ -135,11 +145,11 @@ class ToolStartEvent(PluginBaseEvent): @dataclass class ToolEndEvent(PluginBaseEvent): """Event emitted when a tool execution ends.""" - type: str = field(default='plugin.llm.tool.end', init=False) + + type: str = field(default="plugin.llm.tool.end", init=False) tool_name: str = "" success: bool = True result: Optional[Any] = None error: Optional[str] = None tool_call_id: Optional[str] = None execution_time_ms: Optional[float] = None - diff --git a/agents-core/vision_agents/core/llm/function_registry.py b/agents-core/vision_agents/core/llm/function_registry.py index a35f2995..1f7c187d 100644 --- a/agents-core/vision_agents/core/llm/function_registry.py +++ b/agents-core/vision_agents/core/llm/function_registry.py @@ -15,6 +15,7 @@ @dataclass class FunctionParameter: """Represents a parameter of a function.""" + name: str type: Type description: Optional[str] = None @@ -25,6 +26,7 @@ class FunctionParameter: @dataclass class FunctionDefinition: """Represents a complete function definition.""" + name: str description: str parameters: List[FunctionParameter] @@ -34,77 +36,82 @@ class FunctionDefinition: class FunctionRegistry: """Registry for managing available functions that can be called by LLMs.""" - + def __init__(self): self._functions: Dict[str, FunctionDefinition] = {} - - def register(self, - name: Optional[str] = None, - description: Optional[str] = None) -> Callable: + + def register( + self, name: Optional[str] = None, description: Optional[str] = None + ) -> Callable: """ Decorator to register a function with the registry. - + Args: name: Optional custom name for the function. If not provided, uses the function name. description: Optional description for the function. If not provided, uses the docstring. - + Returns: Decorator function. 
""" + def decorator(func: Callable) -> Callable: func_name = name or func.__name__ func_description = description or func.__doc__ or "" - + # Extract type hints type_hints = get_type_hints(func) sig = inspect.signature(func) - + parameters = [] for param_name, param in sig.parameters.items(): - if param_name == 'self': + if param_name == "self": continue - + param_type = type_hints.get(param_name, type(None)) param_description = None - + # Check if there's a docstring with parameter descriptions if func.__doc__: # Simple extraction of parameter descriptions from docstring # This is a basic implementation - could be enhanced with proper parsing pass - - parameters.append(FunctionParameter( - name=param_name, - type=param_type, - description=param_description, - required=param.default == inspect.Parameter.empty, - default=param.default if param.default != inspect.Parameter.empty else None - )) - + + parameters.append( + FunctionParameter( + name=param_name, + type=param_type, + description=param_description, + required=param.default == inspect.Parameter.empty, + default=param.default + if param.default != inspect.Parameter.empty + else None, + ) + ) + # Determine return type - return_type = type_hints.get('return', None) - + return_type = type_hints.get("return", None) + function_def = FunctionDefinition( name=func_name, description=func_description, parameters=parameters, function=func, - returns=return_type + returns=return_type, ) - + self._functions[func_name] = function_def return func - + return decorator - + def get_function(self, name: str) -> Optional[FunctionDefinition]: """Get a function definition by name.""" return self._functions.get(name) - + def list_functions(self) -> List[str]: """Get a list of all registered function names.""" return list(self._functions.keys()) - + def get_tool_schemas(self) -> List[ToolSchema]: """Get tool schemas for all registered functions.""" schemas = [] @@ -112,82 +119,81 @@ def get_tool_schemas(self) -> List[ToolSchema]: schema = self._function_to_tool_schema(func_def) schemas.append(schema) return schemas - + def call_function(self, name: str, arguments: Dict[str, Any]) -> Any: """ Call a registered function with the given arguments. - + Args: name: Name of the function to call. arguments: Dictionary of arguments to pass to the function. - + Returns: Result of the function call. - + Raises: KeyError: If the function is not registered. TypeError: If the arguments don't match the function signature. """ if name not in self._functions: raise KeyError(f"Function '{name}' is not registered") - + func_def = self._functions[name] - + # Validate required parameters for param in func_def.parameters: if param.required and param.name not in arguments: - raise TypeError(f"Missing required parameter '{param.name}' for function '{name}'") - + raise TypeError( + f"Missing required parameter '{param.name}' for function '{name}'" + ) + # Call the function with the provided arguments return func_def.function(**arguments) def get_callable(self, name: str) -> Callable: """ Get the callable function by name. 
- + Args: name: Name of the function - + Returns: The callable function - + Raises: KeyError: If the function is not registered """ if name not in self._functions: raise KeyError(f"Function '{name}' is not registered") - + return self._functions[name].function - + def _function_to_tool_schema(self, func_def: FunctionDefinition) -> ToolSchema: """Convert a function definition to a tool schema.""" properties = {} required = [] - + for param in func_def.parameters: param_schema = self._type_to_json_schema(param.type) if param.description: param_schema["description"] = param.description - + properties[param.name] = param_schema - + if param.required: required.append(param.name) - - schema = { - "type": "object", - "properties": properties - } - + + schema = {"type": "object", "properties": properties} + if required: schema["required"] = required - + return ToolSchema( name=func_def.name, description=func_def.description, - parameters_schema=schema + parameters_schema=schema, ) - + def _type_to_json_schema(self, type_hint: Type) -> Dict[str, Any]: """Convert a Python type hint to a JSON schema.""" # Handle basic types @@ -203,36 +209,30 @@ def _type_to_json_schema(self, type_hint: Type) -> Dict[str, Any]: return {"type": "array"} elif type_hint is dict or type_hint is Dict: return {"type": "object"} - + # Handle Optional types - if hasattr(type_hint, '__origin__') and type_hint.__origin__ is Union: + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: # Check if it's Optional (Union[SomeType, None]) args = type_hint.__args__ if len(args) == 2 and type(None) in args: non_none_type = args[0] if args[1] is type(None) else args[1] return self._type_to_json_schema(non_none_type) - + # Handle List types - if hasattr(type_hint, '__origin__') and type_hint.__origin__ is list: - if hasattr(type_hint, '__args__') and type_hint.__args__: + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is list: + if hasattr(type_hint, "__args__") and type_hint.__args__: item_type = type_hint.__args__[0] - return { - "type": "array", - "items": self._type_to_json_schema(item_type) - } + return {"type": "array", "items": self._type_to_json_schema(item_type)} return {"type": "array"} - + # Handle Dict types - if hasattr(type_hint, '__origin__') and type_hint.__origin__ is dict: + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is dict: return {"type": "object"} - + # Handle Enum types if inspect.isclass(type_hint) and issubclass(type_hint, Enum): - return { - "type": "string", - "enum": [e.value for e in type_hint] - } - + return {"type": "string", "enum": [e.value for e in type_hint]} + # Default fallback return {"type": "string"} diff --git a/agents-core/vision_agents/core/llm/llm.py b/agents-core/vision_agents/core/llm/llm.py index 526bc7b1..9394bf3c 100644 --- a/agents-core/vision_agents/core/llm/llm.py +++ b/agents-core/vision_agents/core/llm/llm.py @@ -3,7 +3,17 @@ import abc import asyncio import json -from typing import Optional, TYPE_CHECKING, Tuple, List, Dict, Any, TypeVar, Callable, Generic +from typing import ( + Optional, + TYPE_CHECKING, + Tuple, + List, + Dict, + Any, + TypeVar, + Callable, + Generic, +) from vision_agents.core.llm import events from vision_agents.core.llm.events import ToolStartEvent, ToolEndEvent @@ -64,7 +74,7 @@ def _build_enhanced_instructions(self) -> Optional[str]: Returns: Enhanced instructions string with markdown file contents included, or None if no parsed instructions """ - if not hasattr(self, 'parsed_instructions') or not 
self.parsed_instructions: + if not hasattr(self, "parsed_instructions") or not self.parsed_instructions: return None parsed = self.parsed_instructions @@ -79,7 +89,9 @@ def _build_enhanced_instructions(self) -> Optional[str]: enhanced_instructions.append(content) else: enhanced_instructions.append(f"\n### {filename}") - enhanced_instructions.append("*(File not found or could not be read)*") + enhanced_instructions.append( + "*(File not found or could not be read)*" + ) return "\n".join(enhanced_instructions) @@ -87,64 +99,72 @@ def _get_tools_for_provider(self) -> List[Dict[str, Any]]: """ Get tools in provider-specific format. This method should be overridden by each LLM implementation. - + Returns: List of tools in the provider's expected format. """ tools = self.get_available_functions() return self._convert_tools_to_provider_format(tools) - - def _convert_tools_to_provider_format(self, tools: List[ToolSchema]) -> List[Dict[str, Any]]: + + def _convert_tools_to_provider_format( + self, tools: List[ToolSchema] + ) -> List[Dict[str, Any]]: """ Convert ToolSchema objects to provider-specific format. This method should be overridden by each LLM implementation. - + Args: tools: List of ToolSchema objects - + Returns: List of tools in provider-specific format """ # Default implementation - should be overridden return [] - - def _extract_tool_calls_from_response(self, response: Any) -> List[NormalizedToolCallItem]: + + def _extract_tool_calls_from_response( + self, response: Any + ) -> List[NormalizedToolCallItem]: """ Extract tool calls from provider-specific response. This method should be overridden by each LLM implementation. - + Args: response: Provider-specific response object - + Returns: List of normalized tool call items """ # Default implementation - should be overridden return [] - - def _extract_tool_calls_from_stream_chunk(self, chunk: Any) -> List[NormalizedToolCallItem]: + + def _extract_tool_calls_from_stream_chunk( + self, chunk: Any + ) -> List[NormalizedToolCallItem]: """ Extract tool calls from a streaming chunk. This method should be overridden by each LLM implementation. - + Args: chunk: Provider-specific streaming chunk - + Returns: List of normalized tool call items """ # Default implementation - should be overridden return [] - - def _create_tool_result_message(self, tool_calls: List[NormalizedToolCallItem], results: List[Any]) -> List[Dict[str, Any]]: + + def _create_tool_result_message( + self, tool_calls: List[NormalizedToolCallItem], results: List[Any] + ) -> List[Dict[str, Any]]: """ Create tool result messages for the provider. This method should be overridden by each LLM implementation. - + Args: tool_calls: List of tool calls that were executed results: List of results from function execution - + Returns: List of tool result messages in provider format """ @@ -158,64 +178,65 @@ def _attach_agent(self, agent: Agent): self.agent = agent self._conversation = agent.conversation self.instructions = agent.instructions - + # Parse instructions to extract @ mentioned markdown files self.parsed_instructions = parse_instructions(agent.instructions) - def register_function(self, - name: Optional[str] = None, - description: Optional[str] = None) -> Callable: + def register_function( + self, name: Optional[str] = None, description: Optional[str] = None + ) -> Callable: """ Decorator to register a function with the LLM's function registry. - + Args: name: Optional custom name for the function. If not provided, uses the function name. 
description: Optional description for the function. If not provided, uses the docstring. - + Returns: Decorator function. """ return self.function_registry.register(name, description) - + def get_available_functions(self) -> List[ToolSchema]: """Get a list of available function schemas.""" return self.function_registry.get_tool_schemas() - + def call_function(self, name: str, arguments: Dict[str, Any]) -> Any: """ Call a registered function with the given arguments. - + Args: name: Name of the function to call. arguments: Dictionary of arguments to pass to the function. - + Returns: Result of the function call. """ return self.function_registry.call_function(name, arguments) - def _tc_key(self, tc: Dict[str, Any]) -> Tuple[Optional[str], str, str]: """Generate a unique key for tool call deduplication. - + Args: tc: Tool call dictionary - + Returns: Tuple of (id, name, arguments_json) for deduplication """ return ( - tc.get("id"), - tc["name"], - json.dumps(tc.get("arguments_json", tc.get("arguments", {})), sort_keys=True) + tc.get("id"), + tc["name"], + json.dumps( + tc.get("arguments_json", tc.get("arguments", {})), sort_keys=True + ), ) async def _maybe_await(self, x): """Await if x is a coroutine, otherwise return x directly. - + Args: x: Value that might be a coroutine - + Returns: Awaited result if coroutine, otherwise x """ @@ -225,23 +246,23 @@ async def _maybe_await(self, x): async def _run_one_tool(self, tc: Dict[str, Any], timeout_s: float): """Run a single tool call with timeout. - + Args: tc: Tool call dictionary timeout_s: Timeout in seconds - + Returns: Tuple of (tool_call, result, error) """ import inspect import time - + args = tc.get("arguments_json", tc.get("arguments", {})) or {} start_time = time.time() - + async def _invoke(): # Get the actual function to check if it's async - if hasattr(self.function_registry, 'get_callable'): + if hasattr(self.function_registry, "get_callable"): fn = self.function_registry.get_callable(tc["name"]) if inspect.iscoroutinefunction(fn): return await fn(**args) @@ -252,62 +273,74 @@ async def _invoke(): # Fallback to existing call_function method res = self.call_function(tc["name"], args) return await self._maybe_await(res) - + try: # Send tool start event - self.events.send(ToolStartEvent( - plugin_name="llm", - tool_name=tc["name"], - arguments=args, - tool_call_id=tc.get("id") - )) - + self.events.send( + ToolStartEvent( + plugin_name="llm", + tool_name=tc["name"], + arguments=args, + tool_call_id=tc.get("id"), + ) + ) + res = await asyncio.wait_for(_invoke(), timeout=timeout_s) execution_time = (time.time() - start_time) * 1000 - + # Send tool end event (success) - self.events.send(ToolEndEvent( - plugin_name="llm", - tool_name=tc["name"], - success=True, - result=res, - tool_call_id=tc.get("id"), - execution_time_ms=execution_time - )) - + self.events.send( + ToolEndEvent( + plugin_name="llm", + tool_name=tc["name"], + success=True, + result=res, + tool_call_id=tc.get("id"), + execution_time_ms=execution_time, + ) + ) + return tc, res, None except Exception as e: execution_time = (time.time() - start_time) * 1000 - + # Send tool end event (error) - self.events.send(ToolEndEvent( - plugin_name="llm", - tool_name=tc["name"], - success=False, - error=str(e), - tool_call_id=tc.get("id"), - execution_time_ms=execution_time - )) - + self.events.send( + ToolEndEvent( + plugin_name="llm", + tool_name=tc["name"], + success=False, + error=str(e), + tool_call_id=tc.get("id"), + execution_time_ms=execution_time, + ) + ) + return tc, {"error": 
str(e)}, e - async def _execute_tools(self, calls: List[Dict[str, Any]], *, max_concurrency: int = 8, timeout_s: float = 30): + async def _execute_tools( + self, + calls: List[Dict[str, Any]], + *, + max_concurrency: int = 8, + timeout_s: float = 30, + ): """Execute multiple tool calls concurrently with timeout. - + Args: calls: List of tool call dictionaries max_concurrency: Maximum number of concurrent tool executions timeout_s: Timeout per tool execution in seconds - + Returns: List of tuples (tool_call, result, error) """ sem = asyncio.Semaphore(max_concurrency) - + async def _guarded(tc): async with sem: return await self._run_one_tool(tc, timeout_s) - + return await asyncio.gather(*[_guarded(tc) for tc in calls]) async def _dedup_and_execute( @@ -319,13 +352,13 @@ async def _dedup_and_execute( seen: Optional[set] = None, ): """De-duplicate (by id/name/args) then execute concurrently. - + Args: calls: List of tool call dictionaries max_concurrency: Maximum number of concurrent tool executions timeout_s: Timeout per tool execution in seconds seen: Set of seen tool call keys for deduplication - + Returns: Tuple of (triples, updated_seen_set) """ @@ -341,16 +374,18 @@ async def _dedup_and_execute( if not to_run: return [], seen # nothing new - triples = await self._execute_tools(to_run, max_concurrency=max_concurrency, timeout_s=timeout_s) + triples = await self._execute_tools( + to_run, max_concurrency=max_concurrency, timeout_s=timeout_s + ) return triples, seen def _sanitize_tool_output(self, value: Any, max_chars: int = 60_000) -> str: """Sanitize tool output to prevent oversized responses. - + Args: value: Tool output value max_chars: Maximum characters allowed - + Returns: Sanitized string output """ diff --git a/agents-core/vision_agents/core/llm/llm_test.py b/agents-core/vision_agents/core/llm/llm_test.py index 9ed92ed7..ac47328f 100644 --- a/agents-core/vision_agents/core/llm/llm_test.py +++ b/agents-core/vision_agents/core/llm/llm_test.py @@ -19,4 +19,3 @@ - STS standardization """ - diff --git a/agents-core/vision_agents/core/llm/realtime.py b/agents-core/vision_agents/core/llm/realtime.py index b109d3d6..dd785ee6 100644 --- a/agents-core/vision_agents/core/llm/realtime.py +++ b/agents-core/vision_agents/core/llm/realtime.py @@ -37,11 +37,12 @@ class Realtime(LLM, abc.ABC): - Transcript outgoing audio """ - fps : int = 1 + + fps: int = 1 def __init__( self, - fps: int = 1, # the number of video frames per second to send (for implementations that support setting fps) + fps: int = 1, # the number of video frames per second to send (for implementations that support setting fps) ): super().__init__() self._is_connected = False @@ -66,7 +67,6 @@ async def connect(self): ... @abc.abstractmethod async def simple_audio_response(self, pcm: PcmData): ... - async def _watch_video_track(self, track: Any, **kwargs) -> None: """Optionally overridden by providers that support video input.""" return None @@ -130,7 +130,9 @@ def _emit_audio_output_event( ) self.events.send(event) - def _emit_partial_transcript_event(self, text: str, user_metadata=None, original=None): + def _emit_partial_transcript_event( + self, text: str, user_metadata=None, original=None + ): event = events.RealtimeTranscriptEvent( text=text, user_metadata=user_metadata, @@ -213,4 +215,3 @@ async def close(self): @abc.abstractmethod async def _close_impl(self): ... 
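`_dedup_and_execute` above pairs a `(id, name, canonical-args)` key with a semaphore-bounded `gather`; the same pattern as a self-contained sketch. `run_tool` is a hypothetical stand-in for `_run_one_tool`, and the key is simplified to plain `arguments` without the `arguments_json` fallback:

```python
import asyncio
import json

def tc_key(tc: dict) -> tuple:
    # Canonicalize arguments so identical calls hash to the same key
    return (tc.get("id"), tc["name"], json.dumps(tc.get("arguments", {}), sort_keys=True))

async def run_tool(tc: dict) -> str:
    await asyncio.sleep(0)  # stand-in for the real tool invocation
    return tc["name"]

async def dedup_and_execute(calls: list, max_concurrency: int = 8) -> list:
    seen, to_run = set(), []
    for tc in calls:
        key = tc_key(tc)
        if key not in seen:
            seen.add(key)
            to_run.append(tc)
    sem = asyncio.Semaphore(max_concurrency)

    async def guarded(tc):
        async with sem:
            return await run_tool(tc)

    return await asyncio.gather(*(guarded(tc) for tc in to_run))

calls = [
    {"name": "a", "arguments": {}},
    {"name": "a", "arguments": {}},  # duplicate, filtered out by tc_key
    {"name": "b", "arguments": {}},
]
print(asyncio.run(dedup_and_execute(calls)))  # ['a', 'b']
```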
- diff --git a/agents-core/vision_agents/core/llm/wrap_method.py b/agents-core/vision_agents/core/llm/wrap_method.py index 995d88b4..6e69b784 100644 --- a/agents-core/vision_agents/core/llm/wrap_method.py +++ b/agents-core/vision_agents/core/llm/wrap_method.py @@ -6,6 +6,7 @@ from typing import Any, Concatenate, ParamSpec, TypeVar import functools + # ---------- The function whose signature we want to reuse ---------- def _native_method( text: str, @@ -15,23 +16,29 @@ ) -> str: return text + # ---------- Typing setup ---------- -P = ParamSpec("P") # will be bound to _echo's parameters +P = ParamSpec("P")  # will be bound to _native_method's parameters -R = TypeVar("R") # will be bound to _echo's return type +R = TypeVar("R")  # will be bound to _native_method's return type T = TypeVar("T") # the instance type (self) + # ---------- The decorator factory ---------- -def wrap_native_method(target: Callable[P, R]) -> Callable[ - [Callable[Concatenate[T, P], R]], - Callable[Concatenate[T, P], R] -]: - def decorator(method: Callable[Concatenate[T, P], R]) -> Callable[Concatenate[T, P], R]: +def wrap_native_method( + target: Callable[P, R], +) -> Callable[[Callable[Concatenate[T, P], R]], Callable[Concatenate[T, P], R]]: + def decorator( + method: Callable[Concatenate[T, P], R], + ) -> Callable[Concatenate[T, P], R]: @functools.wraps(method) def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R: return method(self, *args, **kwargs) + return wrapper + return decorator + # ---------- Usage on an instance method ---------- class MyLLM: @wrap_native_method(_native_method) @@ -40,6 +47,7 @@ def native_method(self, *args: P.args, **kwargs: P.kwargs) -> R: # but keeping the signature here lets IDEs show proper hints even before decoration. return _native_method(*args, **kwargs) + # ---------- Example calls (with full typing support propagated) ---------- mc = MyLLM() @@ -49,4 +57,4 @@ def native_method(self, *args: P.args, **kwargs: P.kwargs) -> R: system="assistant", messages=[{"role": "user", "content": "hi"}], max_tokens="42", -) \ No newline at end of file +) diff --git a/agents-core/vision_agents/core/logging_utils.py b/agents-core/vision_agents/core/logging_utils.py index a547f632..2fcdd29d 100644 --- a/agents-core/vision_agents/core/logging_utils.py +++ b/agents-core/vision_agents/core/logging_utils.py @@ -9,6 +9,7 @@ _ORIGINAL_FACTORY = logging.getLogRecordFactory() _CALL_ID_ENABLED = True + @dataclass(slots=True) class CallContextToken: """Token capturing prior state for restoring logging context.""" @@ -61,8 +62,8 @@ def clear_call_context(token: CallContextToken) -> None: global _CURRENT_CALL_ID - #failing TODO: fix - #call_id_ctx.reset(token.context_token) + # failing TODO: fix + # call_id_ctx.reset(token.context_token) _CURRENT_CALL_ID = token.previous_global diff --git a/agents-core/vision_agents/core/mcp/__init__.py b/agents-core/vision_agents/core/mcp/__init__.py index 17a7919d..6ac7ac1c 100644 --- a/agents-core/vision_agents/core/mcp/__init__.py +++ b/agents-core/vision_agents/core/mcp/__init__.py @@ -6,4 +6,10 @@ from .tool_converter import MCPToolConverter from .mcp_manager import MCPManager -__all__ = ["MCPServerRemote", "MCPServerLocal", "MCPBaseServer", "MCPToolConverter", "MCPManager"] +__all__ = [ + "MCPServerRemote", + "MCPServerLocal", + "MCPBaseServer", + "MCPToolConverter", + "MCPManager", +] diff --git a/agents-core/vision_agents/core/mcp/mcp_base.py b/agents-core/vision_agents/core/mcp/mcp_base.py index 5356e7e6..3039cb8f 100644 ---
a/agents-core/vision_agents/core/mcp/mcp_base.py +++ b/agents-core/vision_agents/core/mcp/mcp_base.py @@ -9,10 +9,10 @@ class MCPBaseServer(ABC): """Base class for MCP server connections.""" - + def __init__(self, session_timeout: float = 300.0): """Initialize the base MCP server. - + Args: session_timeout: How long an established MCP session can sit idle with no tool calls, no traffic (in seconds) """ @@ -22,33 +22,33 @@ def __init__(self, session_timeout: float = 300.0): self._is_connected = False self._last_activity: Optional[float] = None self._timeout_task: Optional[asyncio.Task] = None - + @abstractmethod async def connect(self) -> None: """Connect to the MCP server.""" pass - + @abstractmethod async def disconnect(self) -> None: """Disconnect from the MCP server.""" pass - + @property def is_connected(self) -> bool: """Check if the server is connected.""" return self._is_connected - + async def _update_activity(self) -> None: """Update the last activity timestamp.""" self._last_activity = asyncio.get_event_loop().time() - + async def _start_timeout_monitor(self) -> None: """Start monitoring for session timeout.""" if self._timeout_task: self._timeout_task.cancel() - + self._timeout_task = asyncio.create_task(self._timeout_monitor()) - + async def _timeout_monitor(self) -> None: """Monitor for session timeout.""" while self._is_connected: @@ -56,23 +56,27 @@ async def _timeout_monitor(self) -> None: if self._last_activity and self._is_connected: idle_time = asyncio.get_event_loop().time() - self._last_activity if idle_time > self.session_timeout: - self.logger.warning(f"Session timeout after {idle_time:.1f}s of inactivity") + self.logger.warning( + f"Session timeout after {idle_time:.1f}s of inactivity" + ) await self.disconnect() break - + async def _stop_timeout_monitor(self) -> None: """Stop the timeout monitor.""" if self._timeout_task: self._timeout_task.cancel() self._timeout_task = None - + async def _ensure_connected(self) -> None: """Ensure the server is connected, reconnecting if necessary.""" if not self._is_connected: self.logger.info("Reconnecting to MCP server...") await self.connect() - - async def _call_with_retry(self, operation_name: str, operation_func, *args, **kwargs): + + async def _call_with_retry( + self, operation_name: str, operation_func, *args, **kwargs + ): """Call an MCP operation with auto-reconnect on failure.""" max_retries = 2 for attempt in range(max_retries + 1): @@ -82,85 +86,104 @@ async def _call_with_retry(self, operation_name: str, operation_func, *args, **k return await operation_func(*args, **kwargs) except Exception as e: if attempt < max_retries: - self.logger.warning(f"{operation_name} failed (attempt {attempt + 1}/{max_retries + 1}): {e}. Reconnecting...") + self.logger.warning( + f"{operation_name} failed (attempt {attempt + 1}/{max_retries + 1}): {e}. Reconnecting..." 
+ ) await self.disconnect() await asyncio.sleep(1) # Brief delay before retry else: - self.logger.error(f"{operation_name} failed after {max_retries + 1} attempts: {e}") + self.logger.error( + f"{operation_name} failed after {max_retries + 1} attempts: {e}" + ) raise - + async def _list_tools_impl(self) -> List[types.Tool]: """Internal implementation of list_tools without retry logic.""" if not self._session or not self._is_connected: raise RuntimeError("Not connected to MCP server") - + await self._update_activity() response = await self._session.list_tools() return response.tools - - async def _call_tool_impl(self, name: str, arguments: Dict[str, Any]) -> types.CallToolResult: + + async def _call_tool_impl( + self, name: str, arguments: Dict[str, Any] + ) -> types.CallToolResult: """Internal implementation of call_tool without retry logic.""" if not self._session or not self._is_connected: raise RuntimeError("Not connected to MCP server") - + await self._update_activity() return await self._session.call_tool(name, arguments) - + async def _list_resources_impl(self) -> List[types.Resource]: """Internal implementation of list_resources without retry logic.""" if not self._session or not self._is_connected: raise RuntimeError("Not connected to MCP server") - + await self._update_activity() response = await self._session.list_resources() return response.resources - + async def _read_resource_impl(self, uri: str) -> types.ReadResourceResult: """Internal implementation of read_resource without retry logic.""" if not self._session or not self._is_connected: raise RuntimeError("Not connected to MCP server") - + await self._update_activity() from mcp.types import AnyUrl + return await self._session.read_resource(AnyUrl(uri)) - + async def _list_prompts_impl(self) -> List[types.Prompt]: """Internal implementation of list_prompts without retry logic.""" if not self._session or not self._is_connected: raise RuntimeError("Not connected to MCP server") - + await self._update_activity() response = await self._session.list_prompts() return response.prompts - - async def _get_prompt_impl(self, name: str, arguments: Dict[str, Any]) -> types.GetPromptResult: + + async def _get_prompt_impl( + self, name: str, arguments: Dict[str, Any] + ) -> types.GetPromptResult: """Internal implementation of get_prompt without retry logic.""" if not self._session or not self._is_connected: raise RuntimeError("Not connected to MCP server") - + await self._update_activity() return await self._session.get_prompt(name, arguments) async def list_tools(self) -> List[types.Tool]: """List available tools from the MCP server with auto-reconnect.""" return await self._call_with_retry("list_tools", self._list_tools_impl) - - async def call_tool(self, name: str, arguments: Dict[str, Any]) -> types.CallToolResult: + + async def call_tool( + self, name: str, arguments: Dict[str, Any] + ) -> types.CallToolResult: """Call a tool on the MCP server with auto-reconnect.""" - return await self._call_with_retry("call_tool", self._call_tool_impl, name, arguments) - + return await self._call_with_retry( + "call_tool", self._call_tool_impl, name, arguments + ) + async def list_resources(self) -> List[types.Resource]: """List available resources from the MCP server with auto-reconnect.""" return await self._call_with_retry("list_resources", self._list_resources_impl) - + async def read_resource(self, uri: str) -> types.ReadResourceResult: """Read a resource from the MCP server with auto-reconnect.""" - return await 
self._call_with_retry("read_resource", self._read_resource_impl, uri) - + return await self._call_with_retry( + "read_resource", self._read_resource_impl, uri + ) + async def list_prompts(self) -> List[types.Prompt]: """List available prompts from the MCP server with auto-reconnect.""" return await self._call_with_retry("list_prompts", self._list_prompts_impl) - - async def get_prompt(self, name: str, arguments: Dict[str, Any]) -> types.GetPromptResult: + + async def get_prompt( + self, name: str, arguments: Dict[str, Any] + ) -> types.GetPromptResult: """Get a prompt from the MCP server with auto-reconnect.""" - return await self._call_with_retry("get_prompt", self._get_prompt_impl, name, arguments) + return await self._call_with_retry( + "get_prompt", self._get_prompt_impl, name, arguments + ) diff --git a/agents-core/vision_agents/core/mcp/mcp_manager.py b/agents-core/vision_agents/core/mcp/mcp_manager.py index 89caa537..b6e5cb5f 100644 --- a/agents-core/vision_agents/core/mcp/mcp_manager.py +++ b/agents-core/vision_agents/core/mcp/mcp_manager.py @@ -8,10 +8,10 @@ class MCPManager: """Manages MCP server connections and tool registration for agents.""" - + def __init__(self, mcp_servers: List[MCPBaseServer], llm, logger: logging.Logger): """Initialize the MCP manager. - + Args: mcp_servers: List of MCP servers to manage llm: LLM instance for tool registration @@ -20,7 +20,7 @@ def __init__(self, mcp_servers: List[MCPBaseServer], llm, logger: logging.Logger self.mcp_servers = mcp_servers self.llm = llm self.logger = logger - + async def connect_all(self): """Connect to all configured MCP servers and register their tools.""" if not self.mcp_servers: @@ -89,12 +89,12 @@ async def call_tool( self, server_index: int, tool_name: str, arguments: Dict[str, Any] ) -> Any: """Call a tool on a specific MCP server. - + Args: server_index: Index of the MCP server in the mcp_servers list tool_name: Name of the tool to call arguments: Arguments to pass to the tool - + Returns: The result of the tool call """ diff --git a/agents-core/vision_agents/core/mcp/mcp_server_local.py b/agents-core/vision_agents/core/mcp/mcp_server_local.py index 07ae999f..90020f52 100644 --- a/agents-core/vision_agents/core/mcp/mcp_server_local.py +++ b/agents-core/vision_agents/core/mcp/mcp_server_local.py @@ -10,15 +10,15 @@ class MCPServerLocal(MCPBaseServer): """Local MCP server connection using stdio transport.""" - + def __init__( self, command: str, env: Optional[Dict[str, str]] = None, - session_timeout: float = 300.0 + session_timeout: float = 300.0, ): """Initialize the local MCP server connection. - + Args: command: Command to run the MCP server (e.g., "python", "node", etc.) 
env: Optional environment variables to pass to the server process @@ -31,84 +31,84 @@ def __init__( self._client_context: Optional[object] = None # AsyncGeneratorContextManager self._session_context: Optional[object] = None # ClientSession context manager self._get_session_id_cb: Optional[Callable[[], Optional[str]]] = None - + # Parse command into executable and arguments self._parse_command() - + def _parse_command(self) -> None: """Parse the command string into executable and arguments.""" parts = self.command.split() if not parts: raise ValueError("Command cannot be empty") - + self._executable = parts[0] self._args = parts[1:] if len(parts) > 1 else [] - + async def connect(self) -> None: """Connect to the local MCP server.""" if self._is_connected: self.logger.warning("Already connected to MCP server") return - + try: self.logger.info(f"Connecting to local MCP server: {self.command}") - + # Create server parameters self._server_params = StdioServerParameters( - command=self._executable, - args=self._args, - env=self.env + command=self._executable, args=self._args, env=self.env ) - + # Create the stdio client context self._client_context = stdio_client(self._server_params) # type: ignore[assignment] - + # Enter the context to get the read/write streams # Note: stdio_client only returns (read, write), no session ID callback read, write = await self._client_context.__aenter__() # type: ignore[attr-defined] - + # Create the client session context manager self._session_context = ClientSession(read, write) # type: ignore[assignment] - + # Enter the session context and get the actual session self._session = await self._session_context.__aenter__() # type: ignore[attr-defined] - + # Initialize the connection await self._session.initialize() - + self._is_connected = True await self._update_activity() await self._start_timeout_monitor() - - self.logger.info(f"Successfully connected to local MCP server: {self.command}") - + + self.logger.info( + f"Successfully connected to local MCP server: {self.command}" + ) + except Exception as e: self.logger.error(f"Failed to connect to local MCP server: {e}") # Clean up any partial connection state await self._cleanup_connection() raise - + async def disconnect(self) -> None: """Disconnect from the local MCP server.""" if not self._is_connected: return - + try: self.logger.info("Disconnecting from local MCP server") - + # Stop timeout monitoring await self._stop_timeout_monitor() - + # Clean up the connection await self._cleanup_connection() - + self._is_connected = False self.logger.info("Disconnected from local MCP server") - + except Exception as e: self.logger.error(f"Error disconnecting from local MCP server: {e}") self._is_connected = False - + async def _cleanup_connection(self) -> None: """Clean up the MCP connection resources.""" # Close the session context @@ -118,7 +118,7 @@ async def _cleanup_connection(self) -> None: except Exception as e: self.logger.warning(f"Error closing MCP session context: {e}") self._session_context = None - + # Close the client context if self._client_context: try: @@ -126,20 +126,21 @@ async def _cleanup_connection(self) -> None: except Exception as e: self.logger.warning(f"Error closing MCP client context: {e}") self._client_context = None - + self._session = None self._get_session_id_cb = None - + async def __aenter__(self): """Async context manager entry.""" await self.connect() return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" await self.disconnect() - def 
__repr__(self) -> str: """String representation of the local MCP server.""" - return f"MCPServerLocal(command='{self.command}', connected={self._is_connected})" + return ( + f"MCPServerLocal(command='{self.command}', connected={self._is_connected})" + ) diff --git a/agents-core/vision_agents/core/mcp/mcp_server_remote.py b/agents-core/vision_agents/core/mcp/mcp_server_remote.py index cb400504..cf742822 100644 --- a/agents-core/vision_agents/core/mcp/mcp_server_remote.py +++ b/agents-core/vision_agents/core/mcp/mcp_server_remote.py @@ -12,16 +12,16 @@ class MCPServerRemote(MCPBaseServer): """Remote MCP server connection using HTTP Streamable transport.""" - + def __init__( self, url: str, headers: Optional[Dict[str, str]] = None, timeout: float = 30.0, - session_timeout: float = 300.0 + session_timeout: float = 300.0, ): """Initialize the remote MCP server connection. - + Args: url: URL of the MCP server (e.g., "http://localhost:8001/mcp") headers: Optional HTTP headers to include in requests @@ -35,81 +35,89 @@ def __init__( self._client_context: Optional[object] = None # AsyncGeneratorContextManager self._session_context: Optional[object] = None # ClientSession context manager self._get_session_id_cb: Optional[Callable[[], Optional[str]]] = None - + # Validate URL parsed = urlparse(url) if not parsed.scheme or not parsed.netloc: raise ValueError(f"Invalid URL: {url}") - + async def connect(self) -> None: """Connect to the remote MCP server.""" if self._is_connected: self.logger.warning("Already connected to MCP server") return - + try: self.logger.info(f"Connecting to remote MCP server at {self.url}") - + # Create the HTTP client context self._client_context = streamablehttp_client( # type: ignore[assignment] - self.url, - headers=self.headers, - timeout=timedelta(seconds=self.timeout) + self.url, headers=self.headers, timeout=timedelta(seconds=self.timeout) ) - + # Enter the context to get the read/write streams and session ID callback - read, write, self._get_session_id_cb = await self._client_context.__aenter__() # type: ignore[attr-defined] - + ( + read, + write, + self._get_session_id_cb, + ) = await self._client_context.__aenter__() # type: ignore[attr-defined] + # Create the client session context manager self._session_context = ClientSession(read, write) # type: ignore[assignment] - + # Enter the session context and get the actual session self._session = await self._session_context.__aenter__() # type: ignore[attr-defined] - + # Initialize the connection await self._session.initialize() - + self._is_connected = True await self._update_activity() await self._start_timeout_monitor() - + # Log session ID if available if self._get_session_id_cb is not None: try: session_id = self._get_session_id_cb() - self.logger.info(f"Successfully connected to remote MCP server at {self.url} (session: {session_id})") + self.logger.info( + f"Successfully connected to remote MCP server at {self.url} (session: {session_id})" + ) except Exception as e: - self.logger.info(f"Successfully connected to remote MCP server at {self.url} (session ID unavailable: {e})") + self.logger.info( + f"Successfully connected to remote MCP server at {self.url} (session ID unavailable: {e})" + ) else: - self.logger.info(f"Successfully connected to remote MCP server at {self.url}") - + self.logger.info( + f"Successfully connected to remote MCP server at {self.url}" + ) + except Exception as e: self.logger.error(f"Failed to connect to remote MCP server: {e}") # Clean up any partial connection state await 
self._cleanup_connection() raise - + async def disconnect(self) -> None: """Disconnect from the remote MCP server.""" if not self._is_connected: return - + try: self.logger.info("Disconnecting from remote MCP server") - + # Stop timeout monitoring await self._stop_timeout_monitor() - + # Clean up the connection await self._cleanup_connection() - + self._is_connected = False self.logger.info("Disconnected from remote MCP server") - + except Exception as e: self.logger.error(f"Error disconnecting from remote MCP server: {e}") self._is_connected = False - + async def _cleanup_connection(self) -> None: """Clean up the MCP connection resources.""" # Close the session context @@ -119,7 +127,7 @@ async def _cleanup_connection(self) -> None: except Exception as e: self.logger.warning(f"Error closing MCP session context: {e}") self._session_context = None - + # Close the client context if self._client_context: try: @@ -127,19 +135,18 @@ async def _cleanup_connection(self) -> None: except Exception as e: self.logger.warning(f"Error closing MCP client context: {e}") self._client_context = None - + self._session = None self._get_session_id_cb = None - + async def __aenter__(self): """Async context manager entry.""" await self.connect() return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" await self.disconnect() - def __repr__(self) -> str: """String representation of the remote MCP server.""" diff --git a/agents-core/vision_agents/core/mcp/tool_converter.py b/agents-core/vision_agents/core/mcp/tool_converter.py index 5d8b0ba6..f4ea98d9 100644 --- a/agents-core/vision_agents/core/mcp/tool_converter.py +++ b/agents-core/vision_agents/core/mcp/tool_converter.py @@ -8,78 +8,81 @@ class MCPToolConverter: """Converts MCP tools to function registry format.""" - + @staticmethod def mcp_tool_to_tool_schema(tool: types.Tool) -> ToolSchema: """Convert an MCP tool to a ToolSchema. - + Args: tool: MCP tool object - + Returns: ToolSchema compatible with function registry """ # Convert MCP tool input schema to JSON schema format parameters_schema = MCPToolConverter._convert_input_schema(tool.inputSchema) - + return ToolSchema( name=tool.name, description=tool.description or "", - parameters_schema=parameters_schema + parameters_schema=parameters_schema, ) - + @staticmethod def _convert_input_schema(input_schema: Dict[str, Any]) -> Dict[str, Any]: """Convert MCP input schema to JSON schema format. - + Args: input_schema: MCP tool input schema - + Returns: JSON schema compatible with function registry """ # MCP tools already use JSON schema format, so we can mostly pass through # but we need to ensure it has the right structure schema = input_schema.copy() - + # Ensure required fields are present if "type" not in schema: schema["type"] = "object" - + if "properties" not in schema: schema["properties"] = {} - + # Ensure additionalProperties is set if "additionalProperties" not in schema: schema["additionalProperties"] = False - + return schema - + @staticmethod - def create_mcp_tool_wrapper(server_index: int, tool_name: str, agent_ref) -> "Callable": + def create_mcp_tool_wrapper( + server_index: int, tool_name: str, agent_ref + ) -> "Callable": """Create a wrapper function for calling MCP tools. 
- + Args: server_index: Index of the MCP server in the agent's mcp_servers list tool_name: Name of the tool to call agent_ref: Reference to the agent instance - + Returns: Callable function that can be registered with the function registry """ + async def mcp_tool_wrapper(**kwargs) -> Any: """Wrapper function for MCP tool calls.""" try: result = await agent_ref.call_tool(server_index, tool_name, kwargs) # Extract the actual result from MCP response - if hasattr(result, 'content') and result.content: + if hasattr(result, "content") and result.content: # MCP tools return CallToolResult with content if isinstance(result.content, list) and len(result.content) > 0: # Get the first content item content_item = result.content[0] - if hasattr(content_item, 'text'): + if hasattr(content_item, "text"): return content_item.text - elif hasattr(content_item, 'data'): + elif hasattr(content_item, "data"): return content_item.data else: return str(content_item) @@ -89,6 +92,10 @@ async def mcp_tool_wrapper(**kwargs) -> Any: return str(result) except Exception as e: # Return error information in a structured way - return {"error": str(e), "tool": tool_name, "server_index": server_index} - + return { + "error": str(e), + "tool": tool_name, + "server_index": server_index, + } + return mcp_tool_wrapper diff --git a/agents-core/vision_agents/core/stt/events.py b/agents-core/vision_agents/core/stt/events.py index 8f016960..4f79eeb8 100644 --- a/agents-core/vision_agents/core/stt/events.py +++ b/agents-core/vision_agents/core/stt/events.py @@ -7,7 +7,7 @@ class STTTranscriptEvent(PluginBaseEvent): """Event emitted when a complete transcript is available.""" - type: str = field(default='plugin.stt_transcript', init=False) + type: str = field(default="plugin.stt_transcript", init=False) text: str = "" confidence: Optional[float] = None language: Optional[str] = None @@ -26,7 +26,7 @@ def __post_init__(self): class STTPartialTranscriptEvent(PluginBaseEvent): """Event emitted when a partial transcript is available.""" - type: str = field(default='plugin.stt_partial_transcript', init=False) + type: str = field(default="plugin.stt_partial_transcript", init=False) text: str = "" confidence: Optional[float] = None language: Optional[str] = None @@ -41,7 +41,7 @@ class STTPartialTranscriptEvent(PluginBaseEvent): class STTErrorEvent(PluginBaseEvent): """Event emitted when an STT error occurs.""" - type: str = field(default='plugin.stt_error', init=False) + type: str = field(default="plugin.stt_error", init=False) error: Optional[Exception] = None error_code: Optional[str] = None context: Optional[str] = None @@ -57,7 +57,7 @@ def error_message(self) -> str: class STTConnectionEvent(PluginBaseEvent): """Event emitted for STT connection state changes.""" - type: str = field(default='plugin.stt_connection', init=False) + type: str = field(default="plugin.stt_connection", init=False) connection_state: Optional[ConnectionState] = None provider: Optional[str] = None details: Optional[dict[str, Any]] = None diff --git a/agents-core/vision_agents/core/stt/stt.py b/agents-core/vision_agents/core/stt/stt.py index 18f10714..d5fbc686 100644 --- a/agents-core/vision_agents/core/stt/stt.py +++ b/agents-core/vision_agents/core/stt/stt.py @@ -54,7 +54,7 @@ def __init__( Args: sample_rate: The sample rate of the audio to process, in Hz. 
provider_name: Name of the STT provider (e.g., "deepgram", "moonshine") - """ + """ self._track = None self.sample_rate = sample_rate @@ -64,13 +64,15 @@ def __init__( self.events = EventManager() self.events.register_events_from_module(events, ignore_not_compatible=True) - self.events.send(PluginInitializedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - plugin_type="STT", - provider=self.provider_name, - configuration={"sample_rate": sample_rate}, - )) + self.events.send( + PluginInitializedEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + plugin_type="STT", + provider=self.provider_name, + configuration={"sample_rate": sample_rate}, + ) + ) def _validate_pcm_data(self, pcm_data: PcmData) -> bool: """ @@ -112,18 +114,20 @@ def _emit_transcript_event( user_metadata: User-specific metadata. metadata: Transcription metadata (processing time, confidence, etc.). """ - self.events.send(events.STTTranscriptEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - text=text, - user_metadata=user_metadata, - confidence=metadata.get("confidence"), - language=metadata.get("language"), - processing_time_ms=metadata.get("processing_time_ms"), - audio_duration_ms=metadata.get("audio_duration_ms"), - model_name=metadata.get("model_name"), - words=metadata.get("words"), - )) + self.events.send( + events.STTTranscriptEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + text=text, + user_metadata=user_metadata, + confidence=metadata.get("confidence"), + language=metadata.get("language"), + processing_time_ms=metadata.get("processing_time_ms"), + audio_duration_ms=metadata.get("audio_duration_ms"), + model_name=metadata.get("model_name"), + words=metadata.get("words"), + ) + ) def _emit_partial_transcript_event( self, @@ -139,18 +143,20 @@ def _emit_partial_transcript_event( user_metadata: User-specific metadata. metadata: Transcription metadata (processing time, confidence, etc.). """ - self.events.send(events.STTPartialTranscriptEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - text=text, - user_metadata=user_metadata, - confidence=metadata.get("confidence"), - language=metadata.get("language"), - processing_time_ms=metadata.get("processing_time_ms"), - audio_duration_ms=metadata.get("audio_duration_ms"), - model_name=metadata.get("model_name"), - words=metadata.get("words"), - )) + self.events.send( + events.STTPartialTranscriptEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + text=text, + user_metadata=user_metadata, + confidence=metadata.get("confidence"), + language=metadata.get("language"), + processing_time_ms=metadata.get("processing_time_ms"), + audio_duration_ms=metadata.get("audio_duration_ms"), + model_name=metadata.get("model_name"), + words=metadata.get("words"), + ) + ) def _emit_error_event( self, @@ -166,15 +172,17 @@ def _emit_error_event( context: Additional context about where the error occurred. user_metadata: User-specific metadata. 
""" - self.events.send(events.STTErrorEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - error=error, - context=context, - user_metadata=user_metadata, - error_code=getattr(error, "error_code", None), - is_recoverable=not isinstance(error, (SystemExit, KeyboardInterrupt)), - )) + self.events.send( + events.STTErrorEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + error=error, + context=context, + user_metadata=user_metadata, + error_code=getattr(error, "error_code", None), + is_recoverable=not isinstance(error, (SystemExit, KeyboardInterrupt)), + ) + ) async def process_audio( self, pcm_data: PcmData, participant: Optional[Participant] = None @@ -240,7 +248,9 @@ async def process_audio( @abc.abstractmethod async def _process_audio_impl( - self, pcm_data: PcmData, user_metadata: Optional[Union[Dict[str, Any], Participant]] = None + self, + pcm_data: PcmData, + user_metadata: Optional[Union[Dict[str, Any], Participant]] = None, ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]: """ Implementation-specific method to process audio data. @@ -280,10 +290,12 @@ async def close(self): self._is_closed = True # Emit closure event - self.events.send(PluginClosedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - plugin_type="STT", - provider=self.provider_name, - cleanup_successful=True, - )) + self.events.send( + PluginClosedEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + plugin_type="STT", + provider=self.provider_name, + cleanup_successful=True, + ) + ) diff --git a/agents-core/vision_agents/core/tts/events.py b/agents-core/vision_agents/core/tts/events.py index 88197488..989ea8aa 100644 --- a/agents-core/vision_agents/core/tts/events.py +++ b/agents-core/vision_agents/core/tts/events.py @@ -1,7 +1,5 @@ import uuid -from vision_agents.core.events import ( - PluginBaseEvent, AudioFormat, ConnectionState -) +from vision_agents.core.events import PluginBaseEvent, AudioFormat, ConnectionState from dataclasses import dataclass, field from typing import Optional, Any @@ -72,4 +70,3 @@ class TTSConnectionEvent(PluginBaseEvent): connection_state: Optional[ConnectionState] = None provider: Optional[str] = None details: Optional[dict[str, Any]] = None - diff --git a/agents-core/vision_agents/core/tts/tts.py b/agents-core/vision_agents/core/tts/tts.py index 5653fa8c..8c11db2d 100644 --- a/agents-core/vision_agents/core/tts/tts.py +++ b/agents-core/vision_agents/core/tts/tts.py @@ -52,12 +52,14 @@ def __init__(self, provider_name: Optional[str] = None): self.provider_name = provider_name or self.__class__.__name__ self.events = EventManager() self.events.register_events_from_module(events, ignore_not_compatible=True) - self.events.send(PluginInitializedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - plugin_type="TTS", - provider=self.provider_name, - )) + self.events.send( + PluginInitializedEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + plugin_type="TTS", + provider=self.provider_name, + ) + ) def set_output_track(self, track: AudioStreamTrack) -> None: """ @@ -76,10 +78,10 @@ def track(self): def get_required_framerate(self) -> int: """ Get the required framerate for the audio track. - + This method should be overridden by subclasses to return their specific framerate requirement. Defaults to 16000 Hz. 
- + Returns: The required framerate in Hz """ @@ -88,10 +90,10 @@ def get_required_framerate(self) -> int: def get_required_stereo(self) -> bool: """ Get whether the audio track should be stereo or mono. - + This method should be overridden by subclasses to return their specific stereo requirement. Defaults to False (mono). - + Returns: True if stereo is required, False for mono """ @@ -157,13 +159,15 @@ async def send( "Starting text-to-speech synthesis", extra={"text_length": len(text)} ) - self.events.send(TTSSynthesisStartEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - text=text, - synthesis_id=synthesis_id, - user_metadata=user, - )) + self.events.send( + TTSSynthesisStartEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + text=text, + synthesis_id=synthesis_id, + user_metadata=user, + ) + ) # Synthesize audio audio_data = await self.stream_audio(text, *args, **kwargs) @@ -198,33 +202,41 @@ async def send( await self._track.write(chunk) # Emit structured audio event - self.events.send(TTSAudioEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - audio_data=chunk, - synthesis_id=synthesis_id, - text_source=text, - user_metadata=user, - chunk_index=audio_chunks - 1, - is_final_chunk=False, # We don't know if it's final yet - sample_rate=self._track.framerate if self._track else 16000, - )) + self.events.send( + TTSAudioEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + audio_data=chunk, + synthesis_id=synthesis_id, + text_source=text, + user_metadata=user, + chunk_index=audio_chunks - 1, + is_final_chunk=False, # We don't know if it's final yet + sample_rate=self._track.framerate + if self._track + else 16000, + ) + ) else: # assume it's a Cartesia TTS chunk object total_audio_bytes += len(chunk.data) audio_chunks += 1 await self._track.write(chunk.data) - self.events.send(TTSAudioEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - audio_data=chunk.data, - synthesis_id=synthesis_id, - text_source=text, - user_metadata=user, - chunk_index=audio_chunks - 1, - is_final_chunk=False, # We don't know if it's final yet - sample_rate=self._track.framerate if self._track else 16000, - )) + self.events.send( + TTSAudioEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + audio_data=chunk.data, + synthesis_id=synthesis_id, + text_source=text, + user_metadata=user, + chunk_index=audio_chunks - 1, + is_final_chunk=False, # We don't know if it's final yet + sample_rate=self._track.framerate + if self._track + else 16000, + ) + ) elif hasattr(audio_data, "__iter__") and not isinstance( audio_data, (str, bytes, bytearray) ): @@ -233,17 +245,19 @@ async def send( audio_chunks += 1 await self._track.write(chunk) - self.events.send(TTSAudioEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - audio_data=chunk, - synthesis_id=synthesis_id, - text_source=text, - user_metadata=user, - chunk_index=audio_chunks - 1, - is_final_chunk=False, # We don't know if it's final yet - sample_rate=self._track.framerate if self._track else 16000, - )) + self.events.send( + TTSAudioEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + audio_data=chunk, + synthesis_id=synthesis_id, + text_source=text, + user_metadata=user, + chunk_index=audio_chunks - 1, + is_final_chunk=False, # We don't know if it's final yet + sample_rate=self._track.framerate if self._track else 16000, + ) + ) else: raise TypeError( f"Unsupported return type from synthesize: 
{type(audio_data)}" @@ -261,38 +275,44 @@ async def send( else None ) - self.events.send(TTSSynthesisCompleteEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - synthesis_id=synthesis_id, - text=text, - user_metadata=user, - total_audio_bytes=total_audio_bytes, - synthesis_time_ms=synthesis_time * 1000, - audio_duration_ms=estimated_audio_duration_ms, - chunk_count=audio_chunks, - real_time_factor=real_time_factor, - )) + self.events.send( + TTSSynthesisCompleteEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + synthesis_id=synthesis_id, + text=text, + user_metadata=user, + total_audio_bytes=total_audio_bytes, + synthesis_time_ms=synthesis_time * 1000, + audio_duration_ms=estimated_audio_duration_ms, + chunk_count=audio_chunks, + real_time_factor=real_time_factor, + ) + ) except Exception as e: - self.events.send(TTSErrorEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - error=e, - context="synthesis", - text_source=text, - synthesis_id=synthesis_id, - user_metadata=user, - )) + self.events.send( + TTSErrorEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + error=e, + context="synthesis", + text_source=text, + synthesis_id=synthesis_id, + user_metadata=user, + ) + ) # ASK: why ? # Re-raise to allow the caller to handle the error raise async def close(self): """Close the TTS service and release any resources.""" - self.events.send(PluginClosedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - plugin_type="TTS", - provider=self.provider_name, - cleanup_successful=True, - )) + self.events.send( + PluginClosedEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + plugin_type="TTS", + provider=self.provider_name, + cleanup_successful=True, + ) + ) diff --git a/agents-core/vision_agents/core/turn_detection/events.py b/agents-core/vision_agents/core/turn_detection/events.py index c09059c9..d172a04b 100644 --- a/agents-core/vision_agents/core/turn_detection/events.py +++ b/agents-core/vision_agents/core/turn_detection/events.py @@ -10,14 +10,14 @@ class TurnStartedEvent(PluginBaseEvent): """ Event emitted when a speaker starts their turn. - + Attributes: speaker_id: ID of the speaker who started speaking confidence: Confidence level of the turn detection (0.0-1.0) duration: Duration of audio processed (seconds) custom: Additional metadata specific to the turn detection implementation """ - + type: str = field(default="plugin.turn_started", init=False) speaker_id: Optional[str] = None confidence: Optional[float] = None @@ -29,14 +29,14 @@ class TurnStartedEvent(PluginBaseEvent): class TurnEndedEvent(PluginBaseEvent): """ Event emitted when a speaker completes their turn. 
- + Attributes: speaker_id: ID of the speaker who finished speaking confidence: Confidence level of the turn completion detection (0.0-1.0) duration: Duration of the turn (seconds) custom: Additional metadata specific to the turn detection implementation """ - + type: str = field(default="plugin.turn_ended", init=False) speaker_id: Optional[str] = None confidence: Optional[float] = None @@ -45,4 +45,3 @@ class TurnEndedEvent(PluginBaseEvent): __all__ = ["TurnStartedEvent", "TurnEndedEvent"] - diff --git a/agents-core/vision_agents/core/turn_detection/turn_detection.py b/agents-core/vision_agents/core/turn_detection/turn_detection.py index 40f783d7..248e1e8d 100644 --- a/agents-core/vision_agents/core/turn_detection/turn_detection.py +++ b/agents-core/vision_agents/core/turn_detection/turn_detection.py @@ -74,9 +74,7 @@ class TurnDetector(ABC): """Base implementation for turn detection with common functionality.""" def __init__( - self, - confidence_threshold: float = 0.5, - provider_name: Optional[str] = None + self, confidence_threshold: float = 0.5, provider_name: Optional[str] = None ) -> None: self._confidence_threshold = confidence_threshold self._is_detecting = False @@ -84,12 +82,14 @@ def __init__( self.provider_name = provider_name or self.__class__.__name__ self.events = EventManager() self.events.register_events_from_module(events, ignore_not_compatible=True) - self.events.send(PluginInitializedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - plugin_type="TurnDetection", - provider=self.provider_name, - )) + self.events.send( + PluginInitializedEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + plugin_type="TurnDetection", + provider=self.provider_name, + ) + ) @abstractmethod def is_detecting(self) -> bool: @@ -101,29 +101,33 @@ def _emit_turn_event( ) -> None: """ Emit a turn detection event using the new event system. 
- + Args: event_type: The type of turn event (TURN_STARTED or TURN_ENDED) event_data: Data associated with the event """ if event_type == TurnEvent.TURN_STARTED: - self.events.send(events.TurnStartedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - speaker_id=event_data.speaker_id, - confidence=event_data.confidence, - duration=event_data.duration, - custom=event_data.custom, - )) + self.events.send( + events.TurnStartedEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + speaker_id=event_data.speaker_id, + confidence=event_data.confidence, + duration=event_data.duration, + custom=event_data.custom, + ) + ) elif event_type == TurnEvent.TURN_ENDED: - self.events.send(events.TurnEndedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - speaker_id=event_data.speaker_id, - confidence=event_data.confidence, - duration=event_data.duration, - custom=event_data.custom, - )) + self.events.send( + events.TurnEndedEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + speaker_id=event_data.speaker_id, + confidence=event_data.confidence, + duration=event_data.duration, + custom=event_data.custom, + ) + ) @abstractmethod async def process_audio( diff --git a/agents-core/vision_agents/core/utils/__init__.py b/agents-core/vision_agents/core/utils/__init__.py index 60cb0c43..3c06cadd 100644 --- a/agents-core/vision_agents/core/utils/__init__.py +++ b/agents-core/vision_agents/core/utils/__init__.py @@ -11,5 +11,3 @@ logger = logging.getLogger(__name__) __all__ = ["get_vision_agents_version"] - - diff --git a/agents-core/vision_agents/core/utils/queue.py b/agents-core/vision_agents/core/utils/queue.py index bdc10019..841964ea 100644 --- a/agents-core/vision_agents/core/utils/queue.py +++ b/agents-core/vision_agents/core/utils/queue.py @@ -3,11 +3,13 @@ T = TypeVar("T") + class LatestNQueue(asyncio.Queue, Generic[T]): """ A generic asyncio queue that always keeps only the latest N items. If full on put, it discards oldest items to make room. """ + def __init__(self, maxlen: int): super().__init__(maxsize=maxlen) @@ -25,4 +27,4 @@ def put_latest_nowait(self, item: T) -> None: self.get_nowait() except asyncio.QueueEmpty: break - super().put_nowait(item) \ No newline at end of file + super().put_nowait(item) diff --git a/agents-core/vision_agents/core/utils/utils.py b/agents-core/vision_agents/core/utils/utils.py index 2ef6e558..c90caa0b 100644 --- a/agents-core/vision_agents/core/utils/utils.py +++ b/agents-core/vision_agents/core/utils/utils.py @@ -16,6 +16,7 @@ @dataclass class Instructions: """Container for parsed instructions with input text and markdown files.""" + input_text: str markdown_contents: MarkdownFileContents # Maps filename to file content base_dir: str = "" # Base directory for file search, defaults to empty string @@ -36,19 +37,17 @@ def to_mono(samples: np.ndarray, num_channels: int) -> np.ndarray: return np.asarray(mono_samples, dtype=np.int16) - - def parse_instructions(text: str, base_dir: Optional[str] = None) -> Instructions: """ Parse instructions from a string, extracting @ mentioned markdown files and their contents. - + Args: text: Input text that may contain @ mentions of markdown files base_dir: Base directory to search for markdown files. If None, uses current working directory. 
- + Returns: Instructions object containing the input text and file contents - + Example: >>> text = "Please read @file1.md and @file2.md for context" >>> result = parse_instructions(text) @@ -59,22 +58,22 @@ def parse_instructions(text: str, base_dir: Optional[str] = None) -> Instruction """ # Find all @ mentions that look like markdown files # Pattern matches @ followed by filename with .md extension - markdown_pattern = r'@([^\s@]+\.md)' + markdown_pattern = r"@([^\s@]+\.md)" matches = re.findall(markdown_pattern, text) - + # Create a dictionary mapping filename to file content markdown_contents = {} - + # Set base directory for file search if base_dir is None: base_dir = os.getcwd() - + for match in matches: # Try to read the markdown file content file_path = os.path.join(base_dir, match) try: if os.path.isfile(file_path): - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: markdown_contents[match] = f.read() else: # File not found, store empty string @@ -82,21 +81,19 @@ def parse_instructions(text: str, base_dir: Optional[str] = None) -> Instruction except (OSError, IOError, UnicodeDecodeError): # File read error, store empty string markdown_contents[match] = "" - + return Instructions( - input_text=text, - markdown_contents=markdown_contents, - base_dir=base_dir + input_text=text, markdown_contents=markdown_contents, base_dir=base_dir ) def frame_to_png_bytes(frame) -> bytes: """ Convert a video frame to PNG bytes. - + Args: frame: Video frame object that can be converted to an image - + Returns: PNG bytes of the frame, or empty bytes if conversion fails """ @@ -107,7 +104,7 @@ def frame_to_png_bytes(frame) -> bytes: else: arr = frame.to_ndarray(format="rgb24") img = Image.fromarray(arr) - + buf = io.BytesIO() img.save(buf, format="PNG") return buf.getvalue() @@ -119,7 +116,7 @@ def frame_to_png_bytes(frame) -> bytes: def get_vision_agents_version() -> str: """ Get the installed vision-agents package version. - + Returns: Version string, or "unknown" if not available. """ @@ -127,4 +124,3 @@ def get_vision_agents_version() -> str: return importlib.metadata.version("vision-agents") except importlib.metadata.PackageNotFoundError: return "unknown" - diff --git a/agents-core/vision_agents/core/utils/video_forwarder.py b/agents-core/vision_agents/core/utils/video_forwarder.py index 9962f028..ef9cfa7d 100644 --- a/agents-core/vision_agents/core/utils/video_forwarder.py +++ b/agents-core/vision_agents/core/utils/video_forwarder.py @@ -10,6 +10,7 @@ logger = logging.getLogger(__name__) + class VideoForwarder: """ Pulls frames from `input_track` into a latest-N buffer. @@ -18,7 +19,15 @@ class VideoForwarder: - run `start_event_consumer(on_frame)` (push model via callback). `fps` limits how often frames are forwarded to consumers (coalescing to newest). 
""" - def __init__(self, input_track: VideoStreamTrack, *, max_buffer: int = 10, fps: Optional[float] = 30, name: str = "video-forwarder"): + + def __init__( + self, + input_track: VideoStreamTrack, + *, + max_buffer: int = 10, + fps: Optional[float] = 30, + name: str = "video-forwarder", + ): self.input_track = input_track self.queue: LatestNQueue[Frame] = LatestNQueue(maxlen=max_buffer) self.fps = fps # None = unlimited, else forward at ~fps @@ -62,7 +71,9 @@ def _task_done(self, task: asyncio.Task) -> None: self._tasks.discard(task) exc = task.exception() if exc: - logger.error("%s: Task failed with exception: %s", self.name, exc, exc_info=exc) + logger.error( + "%s: Task failed with exception: %s", self.name, exc, exc_info=exc + ) if task.cancelled(): return @@ -71,12 +82,14 @@ def _task_done(self, task: asyncio.Task) -> None: async def _producer(self): try: while not self._stopped.is_set(): - frame : Frame = await self.input_track.recv() + frame: Frame = await self.input_track.recv() await self.queue.put_latest(frame) except asyncio.CancelledError: raise except Exception as e: - logger.error("%s: Producer failed with exception: %s", self.name, e, exc_info=True) + logger.error( + "%s: Producer failed with exception: %s", self.name, e, exc_info=True + ) raise # ---------- consumer API (pull one frame; coalesce backlog to newest) ---------- @@ -112,7 +125,7 @@ async def start_event_consumer( """ Starts a task that calls `on_frame(latest_frame)` at ~fps. If fps is None, it forwards as fast as frames arrive (still coalescing). - + Args: on_frame: Callback function to receive frames fps: Frame rate for this consumer (overrides default). None = unlimited. @@ -122,10 +135,12 @@ async def start_event_consumer( # Use consumer-specific fps if provided, otherwise fall back to forwarder's default fps consumer_fps = fps if fps is not None else self.fps consumer_label = consumer_name or "consumer" - + async def _consumer(): loop = asyncio.get_running_loop() - min_interval = (1.0 / consumer_fps) if (consumer_fps and consumer_fps > 0) else 0.0 + min_interval = ( + (1.0 / consumer_fps) if (consumer_fps and consumer_fps > 0) else 0.0 + ) last_ts = 0.0 is_coro = asyncio.iscoroutinefunction(on_frame) frames_forwarded = 0 diff --git a/agents-core/vision_agents/core/vad/events.py b/agents-core/vision_agents/core/vad/events.py index 52ea022a..e8b6bb25 100644 --- a/agents-core/vision_agents/core/vad/events.py +++ b/agents-core/vision_agents/core/vad/events.py @@ -8,7 +8,7 @@ class VADSpeechStartEvent(PluginBaseEvent): """Event emitted when speech begins.""" - type: str = field(default='plugin.vad_speech_start', init=False) + type: str = field(default="plugin.vad_speech_start", init=False) speech_probability: float = 0.0 activation_threshold: float = 0.0 frame_count: int = 1 @@ -19,7 +19,7 @@ class VADSpeechStartEvent(PluginBaseEvent): class VADSpeechEndEvent(PluginBaseEvent): """Event emitted when speech ends.""" - type: str = field(default='plugin.vad_speech_end', init=False) + type: str = field(default="plugin.vad_speech_end", init=False) speech_probability: float = 0.0 deactivation_threshold: float = 0.0 total_speech_duration_ms: float = 0.0 @@ -30,7 +30,7 @@ class VADSpeechEndEvent(PluginBaseEvent): class VADAudioEvent(PluginBaseEvent): """Event emitted when VAD detects complete speech segment.""" - type: str = field(default='plugin.vad_audio', init=False) + type: str = field(default="plugin.vad_audio", init=False) audio_data: Optional[bytes] = None # PCM audio data sample_rate: int = 16000 audio_format: 
AudioFormat = AudioFormat.PCM_S16 @@ -44,7 +44,7 @@ class VADAudioEvent(PluginBaseEvent): class VADPartialEvent(PluginBaseEvent): """Event emitted during ongoing speech detection.""" - type: str = field(default='plugin.vad_partial', init=False) + type: str = field(default="plugin.vad_partial", init=False) audio_data: Optional[bytes] = None # PCM audio data sample_rate: int = 16000 audio_format: AudioFormat = AudioFormat.PCM_S16 @@ -59,7 +59,7 @@ class VADPartialEvent(PluginBaseEvent): class VADInferenceEvent(PluginBaseEvent): """Event emitted after each VAD inference window.""" - type: str = field(default='plugin.vad_inference', init=False) + type: str = field(default="plugin.vad_inference", init=False) speech_probability: float = 0.0 inference_time_ms: float = 0.0 window_samples: int = 0 @@ -74,7 +74,7 @@ class VADInferenceEvent(PluginBaseEvent): class VADErrorEvent(PluginBaseEvent): """Event emitted when a VAD error occurs.""" - type: str = field(default='plugin.vad_error', init=False) + type: str = field(default="plugin.vad_error", init=False) error: Optional[Exception] = None error_code: Optional[str] = None context: Optional[str] = None diff --git a/agents-core/vision_agents/core/vad/vad.py b/agents-core/vision_agents/core/vad/vad.py index 19ba98e7..142d5058 100644 --- a/agents-core/vision_agents/core/vad/vad.py +++ b/agents-core/vision_agents/core/vad/vad.py @@ -290,7 +290,9 @@ async def _process_frame( frame_bytes = numpy_array_to_bytes(frame.samples) self.speech_buffer.extend(frame_bytes) - async def _flush_speech_buffer(self, user: Optional[Union[Dict[str, Any], Participant]] = None) -> None: + async def _flush_speech_buffer( + self, user: Optional[Union[Dict[str, Any], Participant]] = None + ) -> None: """ Flush the accumulated speech buffer if it meets minimum length requirements. 
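A minimal usage sketch of the MCP connection classes touched above. The URL and the "echo" tool are hypothetical; the constructor arguments, the async context manager, list_tools, and call_tool all appear in the mcp_server_remote.py and mcp_base.py hunks, with each public call routed through _call_with_retry for auto-reconnect.

import asyncio

from vision_agents.core.mcp.mcp_server_remote import MCPServerRemote


async def main() -> None:
    server = MCPServerRemote(
        url="http://localhost:8001/mcp",  # hypothetical endpoint
        timeout=10.0,  # HTTP connect timeout in seconds
        session_timeout=300.0,  # idle time before the timeout monitor disconnects
    )
    # __aenter__/__aexit__ wrap connect()/disconnect() for us.
    async with server:
        tools = await server.list_tools()  # reconnects and retries on failure
        print([tool.name for tool in tools])
        result = await server.call_tool("echo", {"text": "hi"})  # hypothetical tool
        print(result.content)


asyncio.run(main())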
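LatestNQueue, which VideoForwarder uses as its frame buffer, keeps only the newest N items. A small sketch of that drop-oldest behavior, using only put_latest_nowait from the queue.py hunk above plus asyncio.Queue's standard API:

import asyncio

from vision_agents.core.utils.queue import LatestNQueue


async def main() -> None:
    q: LatestNQueue[int] = LatestNQueue(maxlen=3)
    for i in range(5):
        q.put_latest_nowait(i)  # once full, the oldest item is evicted first
    drained = []
    while not q.empty():
        drained.append(q.get_nowait())
    print(drained)  # [2, 3, 4]: only the latest three remain


asyncio.run(main())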
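parse_instructions backs the "Read @golf_coach.md" style instructions used in the examples below: @ mentioned .md files are resolved relative to base_dir, and a missing or unreadable file maps to an empty string. A short sketch (the file need not exist on disk):

from vision_agents.core.utils.utils import parse_instructions

instructions = parse_instructions(
    "Read @golf_coach.md before giving feedback", base_dir="."
)
print(instructions.input_text)  # the original text, unchanged
print(instructions.markdown_contents)  # {"golf_coach.md": "<contents, or '' if unreadable>"}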
diff --git a/conftest.py b/conftest.py index 5003845c..0b2eb396 100644 --- a/conftest.py +++ b/conftest.py @@ -30,7 +30,7 @@ def assets_dir(): def mia_audio_16khz(): """Load mia.mp3 and convert to 16kHz PCM data.""" audio_file_path = os.path.join(get_assets_dir(), "mia.mp3") - + # Load audio file using PyAV container = av.open(audio_file_path) audio_stream = container.streams.audio[0] @@ -40,11 +40,7 @@ def mia_audio_16khz(): # Create resampler if needed resampler = None if original_sample_rate != target_rate: - resampler = av.AudioResampler( - format='s16', - layout='mono', - rate=target_rate - ) + resampler = av.AudioResampler(format="s16", layout="mono", rate=target_rate) # Read all audio frames samples = [] @@ -68,11 +64,7 @@ def mia_audio_16khz(): container.close() # Create PCM data - pcm = PcmData( - samples=samples, - sample_rate=target_rate, - format="s16" - ) + pcm = PcmData(samples=samples, sample_rate=target_rate, format="s16") return pcm @@ -81,7 +73,7 @@ def mia_audio_16khz(): async def bunny_video_track(): """Create RealVideoTrack from video file.""" from aiortc import VideoStreamTrack - + video_file_path = os.path.join(get_assets_dir(), "bunny_3s.mp4") class RealVideoTrack(VideoStreamTrack): @@ -101,12 +93,12 @@ async def recv(self): for frame in self.container.decode(self.video_stream): if frame is None: raise asyncio.CancelledError("End of video stream") - + self.frame_count += 1 frame = frame.to_rgb() await asyncio.sleep(self.frame_duration) return frame - + raise asyncio.CancelledError("End of video stream") except asyncio.CancelledError: @@ -123,4 +115,3 @@ async def recv(self): yield track finally: track.container.close() - diff --git a/examples/01_simple_agent_example/simple_agent_example.py b/examples/01_simple_agent_example/simple_agent_example.py index 3b063e66..3009ee69 100644 --- a/examples/01_simple_agent_example/simple_agent_example.py +++ b/examples/01_simple_agent_example/simple_agent_example.py @@ -7,20 +7,25 @@ load_dotenv() + async def start_agent() -> None: llm = openai.LLM(model="gpt-4o-mini") # create an agent to run with Stream's edge, openAI llm agent = Agent( edge=getstream.Edge(), # low latency edge. clients for React, iOS, Android, RN, Flutter etc. - agent_user=User(name="My happy AI friend", id="agent"), # the user object for the agent (name, image etc) + agent_user=User( + name="My happy AI friend", id="agent" + ), # the user object for the agent (name, image etc) instructions="You're a voice AI assistant. Keep responses short and conversational. Don't use special characters or formatting. Be friendly and helpful.", processors=[], # processors can fetch extra data, check images/audio data or transform video # llm with tts & stt. 
if you use a realtime (sts capable) llm the tts, stt and vad aren't needed llm=llm, tts=cartesia.TTS(), stt=deepgram.STT(), - turn_detection=smart_turn.TurnDetection(buffer_duration=2.0, confidence_threshold=0.5), # Enable turn detection with FAL/ Smart turn - #vad=silero.VAD(), + turn_detection=smart_turn.TurnDetection( + buffer_duration=2.0, confidence_threshold=0.5 + ), # Enable turn detection with FAL/ Smart turn + # vad=silero.VAD(), # realtime version (vad, tts and stt not needed) # llm=openai.Realtime() ) @@ -37,15 +42,15 @@ async def start_agent() -> None: # Example 1: standardized simple response # await agent.llm.simple_response("chat with the user about the weather.") # Example 2: use native openAI create response - # await llm.create_response(input=[ - # { - # "role": "user", - # "content": [ - # {"type": "input_text", "text": "Tell me a short poem about this image"}, - # {"type": "input_image", "image_url": f"https://images.unsplash.com/photo-1757495361144-0c2bfba62b9e?q=80&w=2340&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"}, - # ], - # } - # ],) + # await llm.create_response(input=[ + # { + # "role": "user", + # "content": [ + # {"type": "input_text", "text": "Tell me a short poem about this image"}, + # {"type": "input_image", "image_url": f"https://images.unsplash.com/photo-1757495361144-0c2bfba62b9e?q=80&w=2340&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"}, + # ], + # } + # ],) # run till the call ends await agent.finish() diff --git a/examples/02_golf_coach_example/golf_coach_example.py b/examples/02_golf_coach_example/golf_coach_example.py index bb74f9e4..7b72d693 100644 --- a/examples/02_golf_coach_example/golf_coach_example.py +++ b/examples/02_golf_coach_example/golf_coach_example.py @@ -11,12 +11,14 @@ async def start_agent() -> None: agent = Agent( - edge=getstream.Edge(), # use stream for edge video transport + edge=getstream.Edge(), # use stream for edge video transport agent_user=User(name="AI golf coach"), - instructions="Read @golf_coach.md", # read the golf coach markdown instructions - llm=gemini.Realtime(fps=10), # Careful with FPS can get expensive + instructions="Read @golf_coach.md", # read the golf coach markdown instructions + llm=gemini.Realtime(fps=10), # Careful with FPS can get expensive # llm=openai.Realtime(fps=10), use this to switch to openai - processors=[ultralytics.YOLOPoseProcessor(model_path="yolo11n-pose.pt")], # realtime pose detection with yolo + processors=[ + ultralytics.YOLOPoseProcessor(model_path="yolo11n-pose.pt") + ], # realtime pose detection with yolo ) await agent.create_user() @@ -28,7 +30,9 @@ async def start_agent() -> None: with await agent.join(call): await agent.edge.open_demo(call) # all LLMs support a simple_response method and a more advanced native method (so you can always use the latest LLM features) - await agent.llm.simple_response(text="Say hi. After the user does their golf swing offer helpful feedback.") + await agent.llm.simple_response( + text="Say hi. After the user does their golf swing offer helpful feedback." 
+ ) # Gemini's native API is available here # agent.llm.send_realtime_input(text="Hello world") await agent.finish() # run till the call ends diff --git a/examples/other_examples/07_function_calling_example/claude_example.py b/examples/other_examples/07_function_calling_example/claude_example.py index 0efb8693..5613a499 100644 --- a/examples/other_examples/07_function_calling_example/claude_example.py +++ b/examples/other_examples/07_function_calling_example/claude_example.py @@ -11,10 +11,10 @@ async def main(): """Run a simple Claude function calling example.""" - + # Create the LLM llm = anthropic.LLM(model="claude-3-5-sonnet-20241022") - + # Register functions @llm.register_function(description="Get current weather for a location") def get_weather(location: str): @@ -23,35 +23,34 @@ def get_weather(location: str): "location": location, "temperature": "22°C", "condition": "Sunny", - "humidity": "65%" + "humidity": "65%", } - + @llm.register_function(description="Calculate the sum of two numbers") def calculate_sum(a: int, b: int): """Calculate the sum of two numbers.""" return a + b - + # Test queries queries = [ "What's the weather like in New York?", "Calculate 15 + 27 for me", - "What's the weather in London and calculate 100 + 200?" + "What's the weather in London and calculate 100 + 200?", ] - + print("Claude Function Calling Example") print("=" * 50) - + for query in queries: print(f"\nQuery: {query}") print("-" * 30) - + response = await llm.create_message( - messages=[{"role": "user", "content": query}], - max_tokens=1000 + messages=[{"role": "user", "content": query}], max_tokens=1000 ) - + print(f"Response: {response.text}") if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/examples/other_examples/07_function_calling_example/gemini_example.py b/examples/other_examples/07_function_calling_example/gemini_example.py index 0a67d126..10badd6c 100644 --- a/examples/other_examples/07_function_calling_example/gemini_example.py +++ b/examples/other_examples/07_function_calling_example/gemini_example.py @@ -11,10 +11,10 @@ async def main(): """Run a simple Gemini function calling example.""" - + # Create the LLM llm = gemini.LLM("gemini-2.0-flash") - + # Register functions @llm.register_function(description="Get current weather for a location") def get_weather(location: str): @@ -23,34 +23,32 @@ def get_weather(location: str): "location": location, "temperature": "22°C", "condition": "Sunny", - "humidity": "65%" + "humidity": "65%", } - + @llm.register_function(description="Calculate the sum of two numbers") def calculate_sum(a: int, b: int): """Calculate the sum of two numbers.""" return a + b - + # Test queries queries = [ "What's the weather like in New York?", "Calculate 15 + 27 for me", - "What's the weather in London and calculate 100 + 200?" 
+ "What's the weather in London and calculate 100 + 200?", ] - + print("Gemini Function Calling Example") print("=" * 50) - + for query in queries: print(f"\nQuery: {query}") print("-" * 30) - - response = await llm.send_message( - message=query - ) - + + response = await llm.send_message(message=query) + print(f"Response: {response.text}") if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/examples/other_examples/07_function_calling_example/openai_example.py b/examples/other_examples/07_function_calling_example/openai_example.py index 26fa42f2..68e7926b 100644 --- a/examples/other_examples/07_function_calling_example/openai_example.py +++ b/examples/other_examples/07_function_calling_example/openai_example.py @@ -11,10 +11,10 @@ async def main(): """Run a simple OpenAI function calling example.""" - + # Create the LLM llm = openai.LLM(model="gpt-4o-mini") - + # Register functions @llm.register_function(description="Get current weather for a location") def get_weather(location: str): @@ -23,33 +23,32 @@ def get_weather(location: str): "location": location, "temperature": "22°C", "condition": "Sunny", - "humidity": "65%" + "humidity": "65%", } - + @llm.register_function(description="Calculate the sum of two numbers") def calculate_sum(a: int, b: int): """Calculate the sum of two numbers.""" return a + b - + # Test queries queries = [ "What's the weather like in New York?", "Calculate 15 + 27 for me", - "What's the weather in London and calculate 100 + 200?" + "What's the weather in London and calculate 100 + 200?", ] - + print("OpenAI Function Calling Example") print("=" * 50) - + for query in queries: print(f"\nQuery: {query}") print("-" * 30) - + response = await llm.create_response( - model="gpt-4o-mini", - input=[{"role": "user", "content": query}] + model="gpt-4o-mini", input=[{"role": "user", "content": query}] ) - + print(f"Response: {response.text}") diff --git a/examples/other_examples/09_github_mcp_demo/gemini_realtime_github_mcp_demo.py b/examples/other_examples/09_github_mcp_demo/gemini_realtime_github_mcp_demo.py index c7ae87e7..2178717c 100644 --- a/examples/other_examples/09_github_mcp_demo/gemini_realtime_github_mcp_demo.py +++ b/examples/other_examples/09_github_mcp_demo/gemini_realtime_github_mcp_demo.py @@ -29,39 +29,38 @@ async def main(): """Demonstrate Gemini Realtime with GitHub MCP server integration.""" - + # Get GitHub PAT from environment github_pat = os.getenv("GITHUB_PAT") if not github_pat: logger.error("GITHUB_PAT environment variable not found!") logger.error("Please set GITHUB_PAT in your .env file or environment") return - + # Get Google API key from environment google_api_key = os.getenv("GOOGLE_API_KEY") if not google_api_key: logger.error("GOOGLE_API_KEY environment variable not found!") logger.error("Please set GOOGLE_API_KEY in your .env file or environment") return - + # Create GitHub MCP server github_server = MCPServerRemote( url="https://api.githubcopilot.com/mcp/", headers={"Authorization": f"Bearer {github_pat}"}, timeout=10.0, # Shorter connection timeout - session_timeout=300.0 + session_timeout=300.0, ) - + # Create Gemini Realtime LLM llm = Realtime( - model="gemini-2.5-flash-native-audio-preview-09-2025", - api_key=google_api_key + model="gemini-2.5-flash-native-audio-preview-09-2025", api_key=google_api_key ) - + # Create real edge transport and agent user edge = getstream.Edge() agent_user = User(name="GitHub AI Assistant", id="github-agent") - + # Create agent with GitHub MCP 
server and Gemini Realtime LLM agent = Agent( edge=edge, @@ -71,50 +70,61 @@ async def main(): processors=[], mcp_servers=[github_server], ) - + logger.info("Agent created with Gemini Realtime and GitHub MCP server") logger.info(f"GitHub server: {github_server}") - + try: # Create the agent user await agent.create_user() - + # Set up event handler for when participants join @agent.subscribe async def on_participant_joined(event: CallSessionParticipantJoinedEvent): # Check MCP tools after connection available_functions = agent.llm.get_available_functions() - mcp_functions = [f for f in available_functions if f['name'].startswith('mcp_')] - logger.info(f"✅ Found {len(mcp_functions)} MCP tools available for function calling") - await agent.simple_response(f"Hello {event.participant.user.name}! I'm your GitHub AI assistant powered by Gemini Live. I have access to {len(mcp_functions)} GitHub tools and can help you with repositories, issues, pull requests, and more through voice commands!") - + mcp_functions = [ + f for f in available_functions if f["name"].startswith("mcp_") + ] + logger.info( + f"✅ Found {len(mcp_functions)} MCP tools available for function calling" + ) + await agent.simple_response( + f"Hello {event.participant.user.name}! I'm your GitHub AI assistant powered by Gemini Live. I have access to {len(mcp_functions)} GitHub tools and can help you with repositories, issues, pull requests, and more through voice commands!" + ) + # Create a call call = agent.edge.client.video.call("default", str(uuid4())) - + # Have the agent join the call/room logger.info("🎤 Agent joining call...") with await agent.join(call): # Open the demo UI logger.info("🌐 Opening browser with demo UI...") await agent.edge.open_demo(call) - logger.info("✅ Agent is now live with Gemini Realtime! You can talk to it in the browser.") + logger.info( + "✅ Agent is now live with Gemini Realtime! You can talk to it in the browser." + ) logger.info("Try asking:") logger.info(" - 'What repositories do I have?'") logger.info(" - 'Create a new issue in my repository'") logger.info(" - 'Search for issues with the label bug'") logger.info(" - 'Show me recent pull requests'") logger.info("") - logger.info("The agent will use Gemini Live's real-time function calling to interact with GitHub!") - + logger.info( + "The agent will use Gemini Live's real-time function calling to interact with GitHub!" 
+ ) + # Run until the call ends await agent.finish() - + except Exception as e: logger.error(f"Error with Gemini Realtime GitHub MCP demo: {e}") logger.error("Make sure your GITHUB_PAT and GOOGLE_API_KEY are valid") import traceback + traceback.print_exc() - + # Clean up await agent.close() logger.info("Demo completed!") diff --git a/examples/other_examples/09_github_mcp_demo/github_mcp_demo.py b/examples/other_examples/09_github_mcp_demo/github_mcp_demo.py index 20c9ac5c..560b38ba 100644 --- a/examples/other_examples/09_github_mcp_demo/github_mcp_demo.py +++ b/examples/other_examples/09_github_mcp_demo/github_mcp_demo.py @@ -29,36 +29,36 @@ async def main(): """Demonstrate GitHub MCP server integration.""" - + # Get GitHub PAT from environment github_pat = os.getenv("GITHUB_PAT") if not github_pat: logger.error("GITHUB_PAT environment variable not found!") logger.error("Please set GITHUB_PAT in your .env file or environment") return - + # Create GitHub MCP server github_server = MCPServerRemote( url="https://api.githubcopilot.com/mcp/", headers={"Authorization": f"Bearer {github_pat}"}, timeout=10.0, # Shorter connection timeout - session_timeout=300.0 + session_timeout=300.0, ) - + # Get OpenAI API key from environment openai_api_key = os.getenv("OPENAI_API_KEY") if not openai_api_key: logger.error("OPENAI_API_KEY environment variable not found!") logger.error("Please set OPENAI_API_KEY in your .env file or environment") return - + # Create OpenAI LLM llm = OpenAILLM(model="gpt-4o", api_key=openai_api_key) - + # Create real edge transport and agent user edge = getstream.Edge() agent_user = User(name="GitHub AI Assistant", id="github-agent") - + # Create agent with GitHub MCP server and OpenAI LLM agent = Agent( edge=edge, @@ -69,12 +69,12 @@ async def main(): mcp_servers=[github_server], tts=elevenlabs.TTS(), stt=deepgram.STT(), - vad=silero.VAD() + vad=silero.VAD(), ) - + logger.info("Agent created with GitHub MCP server") logger.info(f"GitHub server: {github_server}") - + try: # Connect to GitHub MCP server with timeout logger.info("Connecting to GitHub MCP server...") @@ -82,22 +82,26 @@ async def main(): # Check if MCP tools were registered with the function registry logger.info("Checking function registry for MCP tools...") available_functions = agent.llm.get_available_functions() - mcp_functions = [f for f in available_functions if f['name'].startswith('mcp_')] - - logger.info(f"✅ Found {len(mcp_functions)} MCP tools registered in function registry") + mcp_functions = [f for f in available_functions if f["name"].startswith("mcp_")] + + logger.info( + f"✅ Found {len(mcp_functions)} MCP tools registered in function registry" + ) logger.info("MCP tools are now available to the LLM for function calling!") - + # Create the agent user await agent.create_user() - + # Set up event handler for when participants join @agent.subscribe async def on_participant_joined(event: CallSessionParticipantJoinedEvent): - await agent.say(f"Hello {event.participant.user.name}! I'm your GitHub AI assistant with access to {len(mcp_functions)} GitHub tools. I can help you with repositories, issues, pull requests, and more!") - + await agent.say( + f"Hello {event.participant.user.name}! I'm your GitHub AI assistant with access to {len(mcp_functions)} GitHub tools. I can help you with repositories, issues, pull requests, and more!" 
+ ) + # Create a call call = agent.edge.client.video.call("default", str(uuid4())) - + # Have the agent join the call/room logger.info("🎤 Agent joining call...") with await agent.join(call): @@ -106,17 +110,20 @@ async def on_participant_joined(event: CallSessionParticipantJoinedEvent): await agent.edge.open_demo(call) logger.info("✅ Agent is now live! You can talk to it in the browser.") - logger.info("Try asking: 'What repositories do I have?' or 'Create a new issue'") - + logger.info( + "Try asking: 'What repositories do I have?' or 'Create a new issue'" + ) + # Run until the call ends await agent.finish() - + except Exception as e: logger.error(f"Error with GitHub MCP server: {e}") logger.error("Make sure your GITHUB_PAT and OPENAI_API_KEY are valid") import traceback + traceback.print_exc() - + # Clean up await agent.close() logger.info("Demo completed!") diff --git a/examples/other_examples/09_github_mcp_demo/openai_realtime_github_mcp_demo.py b/examples/other_examples/09_github_mcp_demo/openai_realtime_github_mcp_demo.py index 60121559..71c6fe70 100644 --- a/examples/other_examples/09_github_mcp_demo/openai_realtime_github_mcp_demo.py +++ b/examples/other_examples/09_github_mcp_demo/openai_realtime_github_mcp_demo.py @@ -29,38 +29,36 @@ async def main(): """Demonstrate OpenAI Realtime with GitHub MCP server integration.""" - + # Get GitHub PAT from environment github_pat = os.getenv("GITHUB_PAT") if not github_pat: logger.error("GITHUB_PAT environment variable not found!") logger.error("Please set GITHUB_PAT in your .env file or environment") return - + # Check OpenAI API key from environment openai_api_key = os.getenv("OPENAI_API_KEY") if not openai_api_key: logger.error("OPENAI_API_KEY environment variable not found!") logger.error("Please set OPENAI_API_KEY in your .env file or environment") return - + # Create GitHub MCP server github_server = MCPServerRemote( url="https://api.githubcopilot.com/mcp/", headers={"Authorization": f"Bearer {github_pat}"}, timeout=10.0, # Shorter connection timeout - session_timeout=300.0 + session_timeout=300.0, ) - + # Create OpenAI Realtime LLM (uses OPENAI_API_KEY from environment) - llm = Realtime( - model="gpt-4o-realtime-preview-2024-12-17" - ) - + llm = Realtime(model="gpt-4o-realtime-preview-2024-12-17") + # Create real edge transport and agent user edge = getstream.Edge() agent_user = User(name="GitHub AI Assistant", id="github-agent") - + # Create agent with GitHub MCP server and OpenAI Realtime LLM agent = Agent( edge=edge, @@ -70,50 +68,61 @@ async def main(): processors=[], mcp_servers=[github_server], ) - + logger.info("Agent created with OpenAI Realtime and GitHub MCP server") logger.info(f"GitHub server: {github_server}") - + try: # Create the agent user await agent.create_user() - + # Set up event handler for when participants join @agent.subscribe async def on_participant_joined(event: CallSessionParticipantJoinedEvent): # Check MCP tools after connection available_functions = agent.llm.get_available_functions() - mcp_functions = [f for f in available_functions if f['name'].startswith('mcp_')] - logger.info(f"✅ Found {len(mcp_functions)} MCP tools available for function calling") - await agent.say(f"Hello {event.participant.user.name}! I'm your GitHub AI assistant powered by OpenAI Realtime.
I have access to {len(mcp_functions)} GitHub tools and can help you with repositories, issues, pull requests, and more through voice commands!") - + mcp_functions = [ + f for f in available_functions if f["name"].startswith("mcp_") + ] + logger.info( + f"✅ Found {len(mcp_functions)} MCP tools available for function calling" + ) + await agent.say( + f"Hello {event.participant.user.name}! I'm your GitHub AI assistant powered by OpenAI Realtime. I have access to {len(mcp_functions)} GitHub tools and can help you with repositories, issues, pull requests, and more through voice commands!" + ) + # Create a call call = agent.edge.client.video.call("default", str(uuid4())) - + # Have the agent join the call/room logger.info("🎤 Agent joining call...") with await agent.join(call): # Open the demo UI logger.info("🌐 Opening browser with demo UI...") await agent.edge.open_demo(call) - logger.info("✅ Agent is now live with OpenAI Realtime! You can talk to it in the browser.") + logger.info( + "✅ Agent is now live with OpenAI Realtime! You can talk to it in the browser." + ) logger.info("Try asking:") logger.info(" - 'What repositories do I have?'") logger.info(" - 'Create a new issue in my repository'") logger.info(" - 'Search for issues with the label bug'") logger.info(" - 'Show me recent pull requests'") logger.info("") - logger.info("The agent will use OpenAI Realtime's real-time function calling to interact with GitHub!") - + logger.info( + "The agent will use OpenAI Realtime's real-time function calling to interact with GitHub!" + ) + # Run until the call ends await agent.finish() - + except Exception as e: logger.error(f"Error with OpenAI Realtime GitHub MCP demo: {e}") logger.error("Make sure your GITHUB_PAT and OPENAI_API_KEY are valid") import traceback + traceback.print_exc() - + # Clean up await agent.close() logger.info("Demo completed!") diff --git a/examples/other_examples/gemini_live_realtime/gemini_live_example.py b/examples/other_examples/gemini_live_realtime/gemini_live_example.py index cf328221..6728e48c 100644 --- a/examples/other_examples/gemini_live_realtime/gemini_live_example.py +++ b/examples/other_examples/gemini_live_realtime/gemini_live_example.py @@ -9,7 +9,10 @@ load_dotenv() -logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s [call_id=%(call_id)s] %(name)s: %(message)s") +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [call_id=%(call_id)s] %(name)s: %(message)s", +) logger = logging.getLogger(__name__) diff --git a/examples/other_examples/openai_realtime_webrtc/openai_realtime_example.py b/examples/other_examples/openai_realtime_webrtc/openai_realtime_example.py index 645a425f..d84cbcb8 100644 --- a/examples/other_examples/openai_realtime_webrtc/openai_realtime_example.py +++ b/examples/other_examples/openai_realtime_webrtc/openai_realtime_example.py @@ -26,7 +26,7 @@ async def start_agent() -> None: # Set the call ID here to be used in the logging call_id = str(uuid4()) - + # create a stream client and a user object client = AsyncStream() agent_user = await client.create_user(name="My happy AI friend") @@ -35,7 +35,8 @@ async def start_agent() -> None: agent = Agent( edge=getstream.Edge(), # low latency edge. clients for React, iOS, Android, RN, Flutter etc. agent_user=agent_user, # the user object for the agent (name, image etc) - instructions=(""" + instructions=( + """ You are a voice assistant. - Greet the user once when asked, then wait for the next user input. - Speak English only. 
@@ -45,8 +46,8 @@ async def start_agent() -> None: - Only respond to clear audio or text. - If the user's audio is not clear (e.g., ambiguous input/background noise/silent/unintelligible) or you didn't fully understand, ask for clarification. """ - ), - # Enable video input and set a conservative default frame rate for realtime responsiveness + ), + # Enable video input and set a conservative default frame rate for realtime responsiveness llm=openai.Realtime(), processors=[], # processors can fetch extra data, check images/audio data or transform video ) @@ -63,7 +64,7 @@ async def start_agent() -> None: logger.info("Joining call") await agent.edge.open_demo(call) logger.info("LLM ready") - #await agent.llm.request_session_info() + # await agent.llm.request_session_info() logger.info("Requested session info") # Wait for a human to join the call before greeting logger.info("Waiting for human to join the call") diff --git a/examples/other_examples/plugins_examples/audio_moderation/main.py b/examples/other_examples/plugins_examples/audio_moderation/main.py index 567bf1d9..8ca574aa 100644 --- a/examples/other_examples/plugins_examples/audio_moderation/main.py +++ b/examples/other_examples/plugins_examples/audio_moderation/main.py @@ -84,7 +84,7 @@ async def handle_transcript(event: STTTranscriptEvent): if event.user_metadata: user = event.user_metadata user_info = user.name if user.name else str(user) - + print(f"[{timestamp}] {user_info}: {event.text}") if event.confidence: print(f" └─ confidence: {event.confidence:.2%}") @@ -92,9 +92,7 @@ async def handle_transcript(event: STTTranscriptEvent): print(f" └─ processing time: {event.processing_time_ms:.1f}ms") # Moderation check (executed in a background thread to avoid blocking) - moderation = await asyncio.to_thread( - moderate, client, event.text, user_info - ) + moderation = await asyncio.to_thread(moderate, client, event.text, user_info) print( f" └─ moderation recommended action: {moderation.recommended_action} for transcript: {event.text}" ) @@ -145,7 +143,7 @@ def setup_moderation_config(client: Stream): print("=" * 55) args = parse_args() - + if args.setup: client = Stream.from_env() setup_moderation_config(client) diff --git a/examples/other_examples/plugins_examples/mcp/main.py b/examples/other_examples/plugins_examples/mcp/main.py index f348bdbe..248d274b 100644 --- a/examples/other_examples/plugins_examples/mcp/main.py +++ b/examples/other_examples/plugins_examples/mcp/main.py @@ -24,13 +24,14 @@ from vision_agents.plugins import deepgram, elevenlabs, openai, getstream from vision_agents.core.mcp import MCPBaseServer + # Example MCP server for demonstration class ExampleMCPServer(MCPBaseServer): """Example MCP server that provides weather information.""" - + def __init__(self): super().__init__("example-server") - + async def get_tools(self): """Return available tools.""" return [ @@ -42,11 +43,11 @@ async def get_tools(self): "properties": { "location": {"type": "string", "description": "City name"} }, - "required": ["location"] - } + "required": ["location"], + }, } ] - + async def call_tool(self, name: str, arguments: dict): """Execute a tool call.""" if name == "get_weather": @@ -54,8 +55,10 @@ async def call_tool(self, name: str, arguments: dict): return f"The weather in {location} is sunny and 72°F" return "Tool not found" + load_dotenv() + async def main(): # Create agent with MCP servers agent = Agent( @@ -74,8 +77,11 @@ async def main(): # Join call and start MCP-enabled conversation with await agent.join(call): - await 
agent.say("Hello! I have access to MCP tools including weather information. How can I help you?") + await agent.say( + "Hello! I have access to MCP tools including weather information. How can I help you?" + ) await agent.finish() + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/examples/other_examples/plugins_examples/stt_deepgram_transcription/main.py b/examples/other_examples/plugins_examples/stt_deepgram_transcription/main.py index 4c7ec779..e2308b07 100644 --- a/examples/other_examples/plugins_examples/stt_deepgram_transcription/main.py +++ b/examples/other_examples/plugins_examples/stt_deepgram_transcription/main.py @@ -29,6 +29,7 @@ load_dotenv() + async def main(): # Create agent with STT + LLM + TTS for conversation agent = Agent( @@ -53,8 +54,10 @@ async def on_my_transcript(event: STTTranscriptEvent): if event.confidence: agent.logger.info(f" └─ confidence: {event.confidence:.2%}") if event.processing_time_ms: - agent.logger.info(f" └─ processing time: {event.processing_time_ms:.1f}ms") - + agent.logger.info( + f" └─ processing time: {event.processing_time_ms:.1f}ms" + ) + # Generate a response to the transcribed text await agent.simple_response(event.text) @@ -73,15 +76,17 @@ async def handle_stt_error(event: STTErrorEvent): if event.context: agent.logger.error(f" └─ context: {event.context}") - # Create call and open demo call = agent.edge.client.video.call("default", str(uuid4())) agent.edge.open_demo(call) # Join call and start conversation with await agent.join(call): - await agent.say("Hello! I'm your transcription bot. I'll listen to what you say, transcribe it, and respond to you. Try saying something!") + await agent.say( + "Hello! I'm your transcription bot. I'll listen to what you say, transcribe it, and respond to you. Try saying something!" + ) await agent.finish() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/other_examples/plugins_examples/stt_moonshine_transcription/main.py b/examples/other_examples/plugins_examples/stt_moonshine_transcription/main.py index d3daaa4d..0f2a8bcf 100644 --- a/examples/other_examples/plugins_examples/stt_moonshine_transcription/main.py +++ b/examples/other_examples/plugins_examples/stt_moonshine_transcription/main.py @@ -27,6 +27,7 @@ load_dotenv() + async def main(): # Create agent with STT + LLM for conversation agent = Agent( @@ -45,7 +46,7 @@ async def handle_transcript(event: STTTranscriptEvent): if event.user_metadata: user = event.user_metadata user_info = user.name if user.name else str(user) - + print(f"[{event.timestamp}] {user_info}: {event.text}") if event.confidence: print(f" └─ confidence: {event.confidence:.2%}") @@ -65,8 +66,11 @@ async def handle_stt_error(event: STTErrorEvent): # Join call and start conversation with await agent.join(call): - await agent.simple_response("Hello! I can transcribe your speech and respond to you.") + await agent.simple_response( + "Hello! I can transcribe your speech and respond to you." 
+ ) await agent.finish() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/other_examples/plugins_examples/tts_cartesia/main.py b/examples/other_examples/plugins_examples/tts_cartesia/main.py index 9da31467..63042345 100644 --- a/examples/other_examples/plugins_examples/tts_cartesia/main.py +++ b/examples/other_examples/plugins_examples/tts_cartesia/main.py @@ -29,6 +29,7 @@ load_dotenv() + async def main(): # Create agent with TTS agent = Agent( @@ -42,12 +43,16 @@ async def main(): # Subscribe to participant joined events @agent.subscribe async def handle_participant_joined(event: CallSessionParticipantJoinedEvent): - await agent.simple_response(f"Hello {event.participant.user.name}! Welcome to the call.") + await agent.simple_response( + f"Hello {event.participant.user.name}! Welcome to the call." + ) # Subscribe to TTS events @agent.subscribe async def handle_tts_audio(event: TTSAudioEvent): - print(f"TTS audio generated: {event.chunk_index} chunks, final: {event.is_final_chunk}") + print( + f"TTS audio generated: {event.chunk_index} chunks, final: {event.is_final_chunk}" + ) # Subscribe to TTS error events @agent.subscribe @@ -64,5 +69,6 @@ async def handle_tts_error(event: TTSErrorEvent): with await agent.join(call): await agent.finish() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/other_examples/plugins_examples/tts_elevenlabs/main.py b/examples/other_examples/plugins_examples/tts_elevenlabs/main.py index 58290892..27b432ff 100644 --- a/examples/other_examples/plugins_examples/tts_elevenlabs/main.py +++ b/examples/other_examples/plugins_examples/tts_elevenlabs/main.py @@ -29,6 +29,7 @@ load_dotenv() + async def main(): # Create agent with TTS agent = Agent( @@ -42,12 +43,16 @@ async def main(): # Subscribe to participant joined events @agent.subscribe async def handle_participant_joined(event: CallSessionParticipantJoinedEvent): - await agent.simple_response(f"Hello {event.participant.user.name}! Welcome to the call.") + await agent.simple_response( + f"Hello {event.participant.user.name}! Welcome to the call." + ) # Subscribe to TTS events @agent.subscribe async def handle_tts_audio(event: TTSAudioEvent): - print(f"TTS audio generated: {event.chunk_index} chunks, final: {event.is_final_chunk}") + print( + f"TTS audio generated: {event.chunk_index} chunks, final: {event.is_final_chunk}" + ) # Subscribe to TTS error events @agent.subscribe @@ -64,5 +69,6 @@ async def handle_tts_error(event: TTSErrorEvent): with await agent.join(call): await agent.finish() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/other_examples/plugins_examples/tts_kokoro/main.py b/examples/other_examples/plugins_examples/tts_kokoro/main.py index b36ad008..f9771a60 100644 --- a/examples/other_examples/plugins_examples/tts_kokoro/main.py +++ b/examples/other_examples/plugins_examples/tts_kokoro/main.py @@ -29,6 +29,7 @@ load_dotenv() + async def main(): # Create agent with TTS agent = Agent( @@ -42,12 +43,16 @@ async def main(): # Subscribe to participant joined events @agent.subscribe async def handle_participant_joined(event: CallSessionParticipantJoinedEvent): - await agent.simple_response(f"Hello {event.participant.user.name}! Welcome to the call.") + await agent.simple_response( + f"Hello {event.participant.user.name}! Welcome to the call." 
+ ) # Subscribe to TTS events @agent.subscribe async def handle_tts_audio(event: TTSAudioEvent): - print(f"TTS audio generated: {event.chunk_index} chunks, final: {event.is_final_chunk}") + print( + f"TTS audio generated: {event.chunk_index} chunks, final: {event.is_final_chunk}" + ) # Subscribe to TTS error events @agent.subscribe @@ -64,5 +69,6 @@ async def handle_tts_error(event: TTSErrorEvent): with await agent.join(call): await agent.finish() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/other_examples/plugins_examples/vad_silero/main.py b/examples/other_examples/plugins_examples/vad_silero/main.py index 61d1f517..51d5a723 100644 --- a/examples/other_examples/plugins_examples/vad_silero/main.py +++ b/examples/other_examples/plugins_examples/vad_silero/main.py @@ -24,6 +24,7 @@ load_dotenv() + async def main(): # Create agent with VAD + LLM for conversation agent = Agent( @@ -42,8 +43,10 @@ async def handle_speech_detected(event: VADAudioEvent): if event.user_metadata: user = event.user_metadata user_info = user.name if user.name else str(user) - - print(f"Speech detected from user: {user_info} - duration: {event.duration_ms:.2f}ms") + + print( + f"Speech detected from user: {user_info} - duration: {event.duration_ms:.2f}ms" + ) # Subscribe to VAD error events @agent.subscribe @@ -58,8 +61,11 @@ async def handle_vad_error(event: VADErrorEvent): # Join call and start conversation with await agent.join(call): - await agent.simple_response("Hello! I can detect when you speak and respond to you.") + await agent.simple_response( + "Hello! I can detect when you speak and respond to you." + ) await agent.finish() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/other_examples/plugins_examples/video_moderation/main.py b/examples/other_examples/plugins_examples/video_moderation/main.py index b83f6dc2..ab176d88 100644 --- a/examples/other_examples/plugins_examples/video_moderation/main.py +++ b/examples/other_examples/plugins_examples/video_moderation/main.py @@ -105,7 +105,7 @@ def moderate(client: Stream, text: str, user_name: str) -> CheckResponse: async def main(): # Load environment variables load_dotenv() - + # Initialize Stream client from ENV client = Stream.from_env() @@ -147,7 +147,7 @@ async def handle_transcript(event: STTTranscriptEvent): if event.user_metadata: user = event.user_metadata user_info = user.name if user.name else str(user) - + print(f"[{timestamp}] {user_info}: {event.text}") if event.confidence: print(f" └─ confidence: {event.confidence:.2%}") @@ -155,9 +155,7 @@ async def handle_transcript(event: STTTranscriptEvent): print(f" └─ processing time: {event.processing_time_ms:.1f}ms") # Moderation check (executed in a background thread to avoid blocking) - moderation = await asyncio.to_thread( - moderate, client, event.text, user_info - ) + moderation = await asyncio.to_thread(moderate, client, event.text, user_info) print( f" └─ moderation recommended action: {moderation.recommended_action} for transcript: {event.text}" ) @@ -184,6 +182,7 @@ async def handle_stt_error(event: STTErrorEvent): except Exception as e: print(f"❌ Error: {e}") import traceback + traceback.print_exc() finally: client.delete_users([user_id]) @@ -219,7 +218,7 @@ def setup_moderation_config(client: Stream): print("=" * 55) args = parse_args() - + if args.setup: client = Stream.from_env() setup_moderation_config(client) diff --git a/examples/other_examples/plugins_examples/wizper_stt_translate/main.py 
b/examples/other_examples/plugins_examples/wizper_stt_translate/main.py index c3f9314b..e0869116 100644 --- a/examples/other_examples/plugins_examples/wizper_stt_translate/main.py +++ b/examples/other_examples/plugins_examples/wizper_stt_translate/main.py @@ -127,7 +127,9 @@ async def handle_speech_detected(event: VADAudioEvent): if event.user_metadata: user = event.user_metadata user_info = user.name if user.name else str(user) - print(f"{time.time()} Speech detected from user: {user_info} duration {event.duration_ms:.2f}ms") + print( + f"{time.time()} Speech detected from user: {user_info} duration {event.duration_ms:.2f}ms" + ) # Subscribe to transcript events @agent.subscribe @@ -137,7 +139,7 @@ async def handle_transcript(event: STTTranscriptEvent): if event.user_metadata: user = event.user_metadata user_info = user.name if user.name else str(user) - + print(f"[{timestamp}] {user_info}: {event.text}") if event.confidence: print(f" └─ confidence: {event.confidence:.2%}") @@ -173,6 +175,7 @@ async def handle_vad_error(event: VADErrorEvent): except Exception as e: print(f"❌ Error: {e}") import traceback + traceback.print_exc() finally: client.delete_users([user_id]) diff --git a/plugins/anthropic/tests/test_anthropic_llm.py b/plugins/anthropic/tests/test_anthropic_llm.py index feef5e0b..416dc09f 100644 --- a/plugins/anthropic/tests/test_anthropic_llm.py +++ b/plugins/anthropic/tests/test_anthropic_llm.py @@ -58,7 +58,7 @@ async def test_native_api(self, llm: ClaudeLLM): @pytest.mark.integration async def test_stream(self, llm: ClaudeLLM): streamingWorks = False - + @llm.events.subscribe async def passed(event: LLMResponseChunkEvent): nonlocal streamingWorks @@ -70,7 +70,6 @@ async def passed(event: LLMResponseChunkEvent): assert streamingWorks - @pytest.mark.integration async def test_memory(self, llm: ClaudeLLM): await llm.simple_response( diff --git a/plugins/anthropic/vision_agents/plugins/anthropic/anthropic_llm.py b/plugins/anthropic/vision_agents/plugins/anthropic/anthropic_llm.py index 30691576..2f04b580 100644 --- a/plugins/anthropic/vision_agents/plugins/anthropic/anthropic_llm.py +++ b/plugins/anthropic/vision_agents/plugins/anthropic/anthropic_llm.py @@ -14,7 +14,10 @@ from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant -from vision_agents.core.llm.events import LLMResponseChunkEvent, LLMResponseCompletedEvent +from vision_agents.core.llm.events import ( + LLMResponseChunkEvent, + LLMResponseCompletedEvent, +) from vision_agents.core.processors import Processor from . 
import events @@ -59,7 +62,9 @@ def __init__( super().__init__() self.events.register_events_from_module(events) self.model = model - self._pending_tool_uses_by_index: Dict[int, Dict[str, Any]] = {} # index -> {id, name, parts: []} + self._pending_tool_uses_by_index: Dict[ + int, Dict[str, Any] + ] = {} # index -> {id, name, parts: []} if client is not None: self.client = client @@ -107,7 +112,7 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: # ensure the AI remembers the past conversation new_messages = kwargs["messages"] - if hasattr(self, '_conversation') and self._conversation: + if hasattr(self, "_conversation") and self._conversation: old_messages = [m.original for m in self._conversation.messages] kwargs["messages"] = old_messages + new_messages # Add messages to conversation @@ -122,7 +127,7 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: # Extract text from Claude's response format - safely handle all text blocks text = self._concat_text_blocks(original.content) llm_response = LLMResponseEvent(original, text) - + # Multi-hop tool calling loop for non-streaming function_calls = self._extract_tool_calls_from_response(original) if function_calls: @@ -131,39 +136,50 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: rounds = 0 seen: set[tuple[str, str, str]] = set() current_calls = function_calls - + while current_calls and rounds < MAX_ROUNDS: # Execute calls concurrently with dedup - triples, seen = await self._dedup_and_execute(current_calls, seen=seen, max_concurrency=8, timeout_s=30) # type: ignore[arg-type] - + triples, seen = await self._dedup_and_execute( + current_calls, seen=seen, max_concurrency=8, timeout_s=30 + ) # type: ignore[arg-type] + if not triples: break - + # Build tool_result user message assistant_content = [] tool_result_blocks = [] for tc, res, err in triples: - assistant_content.append({ - "type": "tool_use", - "id": tc["id"], - "name": tc["name"], - "input": tc["arguments_json"], - }) - + assistant_content.append( + { + "type": "tool_use", + "id": tc["id"], + "name": tc["name"], + "input": tc["arguments_json"], + } + ) + payload = self._sanitize_tool_output(res) - tool_result_blocks.append({ - "type": "tool_result", - "tool_use_id": tc["id"], - "content": payload, - }) + tool_result_blocks.append( + { + "type": "tool_result", + "tool_use_id": tc["id"], + "content": payload, + } + ) assistant_msg = {"role": "assistant", "content": assistant_content} - user_tool_results_msg = {"role": "user", "content": tool_result_blocks} + user_tool_results_msg = { + "role": "user", + "content": tool_result_blocks, + } messages = messages + [assistant_msg, user_tool_results_msg] # Ask again WITH tools so Claude can do another hop tools_cfg = { - "tools": self._convert_tools_to_provider_format(self.get_available_functions()), + "tools": self._convert_tools_to_provider_format( + self.get_available_functions() + ), "tool_choice": {"type": "auto"}, "stream": False, "model": self.model, @@ -172,22 +188,29 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: } follow_up_response = await self.client.messages.create(**tools_cfg) - + # Extract new tool calls from follow-up response - current_calls = self._extract_tool_calls_from_response(follow_up_response) - llm_response = LLMResponseEvent(follow_up_response, self._concat_text_blocks(follow_up_response.content)) + current_calls = self._extract_tool_calls_from_response( + follow_up_response + ) + llm_response = 
LLMResponseEvent( + follow_up_response, + self._concat_text_blocks(follow_up_response.content), + ) rounds += 1 - + # Finalization pass: no tools so Claude must answer in text if current_calls or rounds > 0: # Only if we had tool calls final_response = await self.client.messages.create( model=self.model, - messages=messages, # includes assistant tool_use + user tool_result blocks + messages=messages, # includes assistant tool_use + user tool_result blocks stream=False, - max_tokens=1000 + max_tokens=1000, + ) + llm_response = LLMResponseEvent( + final_response, self._concat_text_blocks(final_response.content) ) - llm_response = LLMResponseEvent(final_response, self._concat_text_blocks(final_response.content)) - + elif isinstance(original, AsyncStream): stream: AsyncStream[RawMessageStreamEvent] = original text_parts: List[str] = [] @@ -195,7 +218,9 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: # 1) First round: read stream, gather initial tool_use calls async for event in stream: - llm_response_optional = self._standardize_and_emit_event(event, text_parts) + llm_response_optional = self._standardize_and_emit_event( + event, text_parts + ) if llm_response_optional is not None: llm_response = llm_response_optional # Collect tool_use calls as they complete (your helper already reconstructs args) @@ -213,7 +238,9 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: last_followup_stream = None while accumulated_calls and rounds < MAX_ROUNDS: # Execute calls concurrently with dedup - triples, seen = await self._dedup_and_execute(accumulated_calls, seen=seen, max_concurrency=8, timeout_s=30) # type: ignore[arg-type] + triples, seen = await self._dedup_and_execute( + accumulated_calls, seen=seen, max_concurrency=8, timeout_s=30 + ) # type: ignore[arg-type] # Build tool_result user message # Also reconstruct the assistant tool_use message that triggered these calls @@ -221,22 +248,26 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: executed_calls: List[NormalizedToolCallItem] = [] for tc, res, err in triples: executed_calls.append(tc) - assistant_content.append({ - "type": "tool_use", - "id": tc["id"], - "name": tc["name"], - "input": tc["arguments_json"], - }) + assistant_content.append( + { + "type": "tool_use", + "id": tc["id"], + "name": tc["name"], + "input": tc["arguments_json"], + } + ) # tool_result blocks (sanitize to keep payloads safe) tool_result_blocks = [] for tc, res, err in triples: payload = self._sanitize_tool_output(res) - tool_result_blocks.append({ - "type": "tool_result", - "tool_use_id": tc["id"], - "content": payload, - }) + tool_result_blocks.append( + { + "type": "tool_result", + "tool_use_id": tc["id"], + "content": payload, + } + ) assistant_msg = {"role": "assistant", "content": assistant_content} user_tool_results_msg = {"role": "user", "content": tool_result_blocks} @@ -244,7 +275,9 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: # Ask again WITH tools so Claude can do another hop tools_cfg = { - "tools": self._convert_tools_to_provider_format(self.get_available_functions()), + "tools": self._convert_tools_to_provider_format( + self.get_available_functions() + ), "tool_choice": {"type": "auto"}, "stream": True, "model": self.model, @@ -259,7 +292,9 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: accumulated_calls = [] # reset; we'll refill with new calls async for ev in follow_up_stream: last_followup_stream = ev - 
llm_response_optional = self._standardize_and_emit_event(ev, follow_up_text_parts) + llm_response_optional = self._standardize_and_emit_event( + ev, follow_up_text_parts + ) if llm_response_optional is not None: llm_response = llm_response_optional new_calls, _ = self._extract_tool_calls_from_stream_chunk(ev, None) @@ -276,14 +311,16 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: if accumulated_calls or rounds > 0: # Only if we had tool calls final_stream = await self.client.messages.create( model=self.model, - messages=messages, # includes assistant tool_use + user tool_result blocks + messages=messages, # includes assistant tool_use + user tool_result blocks stream=True, - max_tokens=1000 + max_tokens=1000, ) final_text_parts: List[str] = [] async for ev in final_stream: last_followup_stream = ev - llm_response_optional = self._standardize_and_emit_event(ev, final_text_parts) + llm_response_optional = self._standardize_and_emit_event( + ev, final_text_parts + ) if llm_response_optional is not None: llm_response = llm_response_optional if final_text_parts: @@ -291,8 +328,16 @@ async def create_message(self, *args, **kwargs) -> LLMResponseEvent[Any]: # 4) Done -> return all collected text total_text = "".join(text_parts) - llm_response = LLMResponseEvent(last_followup_stream or original, total_text) # type: ignore - self.events.send(LLMResponseCompletedEvent(original=last_followup_stream or original, text=total_text, plugin_name="anthropic")) + llm_response = LLMResponseEvent( + last_followup_stream or original, total_text + ) # type: ignore + self.events.send( + LLMResponseCompletedEvent( + original=last_followup_stream or original, + text=total_text, + plugin_name="anthropic", + ) + ) return llm_response @@ -303,10 +348,9 @@ def _standardize_and_emit_event( Forwards the event and also sends out a standardized version (the agent class hooks into that) """ # forward the native event - self.events.send(events.ClaudeStreamEvent( - plugin_name="anthropic", - event_data=event - )) + self.events.send( + events.ClaudeStreamEvent(plugin_name="anthropic", event_data=event) + ) # send a standardized version for delta and response if event.type == "content_block_delta": @@ -314,14 +358,16 @@ def _standardize_and_emit_event( if hasattr(delta_event.delta, "text") and delta_event.delta.text: text_parts.append(delta_event.delta.text) - self.events.send(LLMResponseChunkEvent( - plugin_name="antrhopic", - content_index=delta_event.index, - item_id="", - output_index=0, - sequence_number=0, - delta=delta_event.delta.text, - )) + self.events.send( + LLMResponseChunkEvent( + plugin_name="anthropic", + content_index=delta_event.index, + item_id="", + output_index=0, + sequence_number=0, + delta=delta_event.delta.text, + ) + ) elif event.type == "message_stop": stop_event: RawMessageStopEvent = event total_text = "".join(text_parts) @@ -354,13 +400,15 @@ def _normalize_message(claude_messages: Any) -> List["Message"]: return messages - def _convert_tools_to_provider_format(self, tools: List[ToolSchema]) -> List[Dict[str, Any]]: + def _convert_tools_to_provider_format( + self, tools: List[ToolSchema] + ) -> List[Dict[str, Any]]: """ Convert ToolSchema objects to Anthropic format.
- + Args: tools: List of ToolSchema objects - + Returns: List of tools in Anthropic format """ @@ -369,37 +417,42 @@ def _convert_tools_to_provider_format(self, tools: List[ToolSchema]) -> List[Dic anthropic_tool = { "name": tool["name"], "description": tool.get("description", ""), - "input_schema": tool["parameters_schema"] + "input_schema": tool["parameters_schema"], } anthropic_tools.append(anthropic_tool) return anthropic_tools - def _extract_tool_calls_from_response(self, response: Any) -> List[NormalizedToolCallItem]: + def _extract_tool_calls_from_response( + self, response: Any + ) -> List[NormalizedToolCallItem]: """ Extract tool calls from Anthropic response. - + Args: response: Anthropic response object - + Returns: List of normalized tool call items """ tool_calls = [] - - if hasattr(response, 'content') and response.content: + + if hasattr(response, "content") and response.content: for content_block in response.content: - if hasattr(content_block, 'type') and content_block.type == "tool_use": + if hasattr(content_block, "type") and content_block.type == "tool_use": tool_call: NormalizedToolCallItem = { "type": "tool_call", "id": content_block.id, # Critical: capture the id for tool_result "name": content_block.name, - "arguments_json": content_block.input or {} # normalize to arguments_json + "arguments_json": content_block.input + or {}, # normalize to arguments_json } tool_calls.append(tool_call) - + return tool_calls - def _extract_tool_calls_from_stream_chunk(self, chunk: Any, current_tool_call: Optional[NormalizedToolCallItem] = None) -> tuple[List[NormalizedToolCallItem], Optional[NormalizedToolCallItem]]: # type: ignore[override] + def _extract_tool_calls_from_stream_chunk( + self, chunk: Any, current_tool_call: Optional[NormalizedToolCallItem] = None + ) -> tuple[List[NormalizedToolCallItem], Optional[NormalizedToolCallItem]]: # type: ignore[override] """ Extract tool calls from Anthropic streaming chunk using index-keyed accumulation. 
Args: @@ -409,22 +462,22 @@ def _extract_tool_calls_from_stream_chunk(self, chunk: Any, current_tool_call: O Tuple of (completed tool calls, current tool call being accumulated) """ tool_calls = [] - t = getattr(chunk, 'type', None) + t = getattr(chunk, "type", None) if t == "content_block_start": - cb = getattr(chunk, 'content_block', None) - if getattr(cb, 'type', None) == "tool_use": + cb = getattr(chunk, "content_block", None) + if getattr(cb, "type", None) == "tool_use": if cb is not None: self._pending_tool_uses_by_index[chunk.index] = { "id": cb.id, "name": cb.name, - "parts": [] + "parts": [], } elif t == "content_block_delta": - d = getattr(chunk, 'delta', None) - if getattr(d, 'type', None) == "input_json_delta": - pj = getattr(d, 'partial_json', None) + d = getattr(chunk, "delta", None) + if getattr(d, "type", None) == "input_json_delta": + pj = getattr(d, "partial_json", None) if pj is not None and chunk.index in self._pending_tool_uses_by_index: self._pending_tool_uses_by_index[chunk.index]["parts"].append(pj) @@ -440,12 +493,14 @@ def _extract_tool_calls_from_stream_chunk(self, chunk: Any, current_tool_call: O "type": "tool_call", "id": pending["id"], "name": pending["name"], - "arguments_json": args + "arguments_json": args, } tool_calls.append(tool_call_item) return tool_calls, None - def _create_tool_result_message(self, tool_calls: List[NormalizedToolCallItem], results: List[Any]) -> List[Dict[str, Any]]: + def _create_tool_result_message( + self, tool_calls: List[NormalizedToolCallItem], results: List[Any] + ) -> List[Dict[str, Any]]: """ Create tool result messages for Anthropic. tool_calls: List of tool calls that were executed @@ -461,17 +516,19 @@ def _create_tool_result_message(self, tool_calls: List[NormalizedToolCallItem], payload = str(result) else: payload = json.dumps(result) - blocks.append({ - "type": "tool_result", - "tool_use_id": tool_call["id"], # Critical: must match tool_use.id - "content": payload - }) + blocks.append( + { + "type": "tool_result", + "tool_use_id": tool_call["id"], # Critical: must match tool_use.id + "content": payload, + } + ) return [{"role": "user", "content": blocks}] def _concat_text_blocks(self, content): """Safely extract text from all text blocks in content.""" out = [] for b in content or []: - if getattr(b, 'type', None) == "text" and getattr(b, 'text', None): + if getattr(b, "type", None) == "text" and getattr(b, "text", None): out.append(b.text) return "".join(out) diff --git a/plugins/anthropic/vision_agents/plugins/anthropic/events.py b/plugins/anthropic/vision_agents/plugins/anthropic/events.py index 78a1f53b..188404e0 100644 --- a/plugins/anthropic/vision_agents/plugins/anthropic/events.py +++ b/plugins/anthropic/vision_agents/plugins/anthropic/events.py @@ -6,5 +6,6 @@ @dataclass class ClaudeStreamEvent(PluginBaseEvent): """Event emitted when Claude provides a stream event.""" - type: str = field(default='plugin.anthropic.claude_stream', init=False) + + type: str = field(default="plugin.anthropic.claude_stream", init=False) event_data: Optional[Any] = None diff --git a/plugins/deepgram/tests/test_realtime.py b/plugins/deepgram/tests/test_realtime.py index f5a7f6d0..269d1014 100644 --- a/plugins/deepgram/tests/test_realtime.py +++ b/plugins/deepgram/tests/test_realtime.py @@ -4,7 +4,11 @@ from unittest.mock import patch, MagicMock from vision_agents.plugins import deepgram -from vision_agents.core.stt.events import STTTranscriptEvent, STTPartialTranscriptEvent, STTErrorEvent +from vision_agents.core.stt.events import ( 
+ STTTranscriptEvent, + STTPartialTranscriptEvent, + STTErrorEvent, +) from getstream.video.rtc.track_util import PcmData diff --git a/plugins/deepgram/tests/test_stt.py b/plugins/deepgram/tests/test_stt.py index cc2f7478..4450ab03 100644 --- a/plugins/deepgram/tests/test_stt.py +++ b/plugins/deepgram/tests/test_stt.py @@ -6,7 +6,11 @@ import os from vision_agents.plugins import deepgram -from vision_agents.core.stt.events import STTTranscriptEvent, STTPartialTranscriptEvent, STTErrorEvent +from vision_agents.core.stt.events import ( + STTTranscriptEvent, + STTPartialTranscriptEvent, + STTErrorEvent, +) from getstream.video.rtc.track_util import PcmData from plugins.plugin_test_utils import get_audio_asset, get_json_metadata diff --git a/plugins/deepgram/vision_agents/plugins/deepgram/__init__.py b/plugins/deepgram/vision_agents/plugins/deepgram/__init__.py index 07c829b2..3df699a0 100644 --- a/plugins/deepgram/vision_agents/plugins/deepgram/__init__.py +++ b/plugins/deepgram/vision_agents/plugins/deepgram/__init__.py @@ -4,4 +4,3 @@ __path__ = __import__("pkgutil").extend_path(__path__, __name__) __all__ = ["STT"] - diff --git a/plugins/deepgram/vision_agents/plugins/deepgram/stt.py b/plugins/deepgram/vision_agents/plugins/deepgram/stt.py index 3faa5350..35607811 100644 --- a/plugins/deepgram/vision_agents/plugins/deepgram/stt.py +++ b/plugins/deepgram/vision_agents/plugins/deepgram/stt.py @@ -8,7 +8,12 @@ import os import time -from deepgram import DeepgramClient, LiveTranscriptionEvents, LiveOptions, DeepgramClientOptions +from deepgram import ( + DeepgramClient, + LiveTranscriptionEvents, + LiveOptions, + DeepgramClientOptions, +) from vision_agents.core import stt from getstream.video.rtc.track_util import PcmData @@ -65,9 +70,13 @@ def __init__( # Initialize DeepgramClient with the API key logger.info("Initializing Deepgram client") config = DeepgramClientOptions( - options={"keepalive": "true"} # Comment this out to see the effect of not using keepalive + options={ + "keepalive": "true" + } # Comment this out to see the effect of not using keepalive + ) + self.deepgram = ( + client if client is not None else DeepgramClient(api_key, config) ) - self.deepgram = client if client is not None else DeepgramClient(api_key, config) self.dg_connection: Optional[Any] = None self.options = options or LiveOptions( model="nova-2", @@ -204,7 +213,9 @@ def handle_error(conn, error=None): self._emit_error_event(e, "Deepgram connection setup") async def _process_audio_impl( - self, pcm_data: PcmData, user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None + self, + pcm_data: PcmData, + user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None, ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]: """ Process audio data through Deepgram for transcription. 
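A note on the one behavioral flag in the Deepgram hunk above: keepalive. The inline comment ("Comment this out to see the effect of not using keepalive") hints at why it is on by default: without it, Deepgram drops an idle live-transcription websocket during long silences. Below is a minimal sketch of that construction path, using only names that already appear in this diff (DeepgramClient, DeepgramClientOptions, LiveOptions); the DEEPGRAM_API_KEY variable name is an assumption for illustration, since the plugin receives api_key as a parameter:

    import os

    from deepgram import DeepgramClient, DeepgramClientOptions, LiveOptions

    # Keepalive pings stop Deepgram from closing the socket while no audio flows.
    config = DeepgramClientOptions(options={"keepalive": "true"})

    # Env var name assumed for this sketch; the plugin takes api_key as an argument.
    client = DeepgramClient(os.environ["DEEPGRAM_API_KEY"], config)

    # Same default model the plugin falls back to when no LiveOptions are supplied;
    # the remaining options from the plugin's defaults are omitted here.
    options = LiveOptions(model="nova-2")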
diff --git a/plugins/elevenlabs/vision_agents/plugins/elevenlabs/__init__.py b/plugins/elevenlabs/vision_agents/plugins/elevenlabs/__init__.py index 88218881..4aaee0ff 100644 --- a/plugins/elevenlabs/vision_agents/plugins/elevenlabs/__init__.py +++ b/plugins/elevenlabs/vision_agents/plugins/elevenlabs/__init__.py @@ -4,4 +4,3 @@ __path__ = __import__("pkgutil").extend_path(__path__, __name__) __all__ = ["TTS"] - diff --git a/plugins/gemini/tests/test_gemini_llm.py b/plugins/gemini/tests/test_gemini_llm.py index 093c3969..ba98ab40 100644 --- a/plugins/gemini/tests/test_gemini_llm.py +++ b/plugins/gemini/tests/test_gemini_llm.py @@ -10,9 +10,7 @@ load_dotenv() - class TestGeminiLLM: - def test_message(self): messages = GeminiLLM._normalize_message("say hi") assert isinstance(messages[0], Message) @@ -47,14 +45,14 @@ async def test_native_api(self, llm: GeminiLLM): @pytest.mark.integration async def test_stream(self, llm: GeminiLLM): streamingWorks = False - + @llm.events.subscribe async def passed(event: LLMResponseChunkEvent): nonlocal streamingWorks streamingWorks = True - + await llm.simple_response("Explain magma to a 5 year old") - + # Wait for all events in queue to be processed await llm.events.wait() @@ -63,7 +61,9 @@ async def passed(event: LLMResponseChunkEvent): @pytest.mark.integration async def test_memory(self, llm: GeminiLLM): await llm.simple_response(text="There are 2 dogs in the room") - response = await llm.simple_response(text="How many paws are there in the room?") + response = await llm.simple_response( + text="How many paws are there in the room?" + ) assert "8" in response.text or "eight" in response.text diff --git a/plugins/gemini/tests/test_gemini_realtime.py b/plugins/gemini/tests/test_gemini_realtime.py index 355befff..152ac3d2 100644 --- a/plugins/gemini/tests/test_gemini_realtime.py +++ b/plugins/gemini/tests/test_gemini_realtime.py @@ -29,11 +29,11 @@ async def test_simple_response_flow(self, realtime): """Test sending a simple text message and receiving response""" # Send a simple message events = [] - + @realtime.events.subscribe async def on_audio(event: RealtimeAudioOutputEvent): events.append(event) - + await asyncio.sleep(0.01) await realtime.connect() await realtime.simple_response("Hello, can you hear me?") @@ -46,15 +46,17 @@ async def on_audio(event: RealtimeAudioOutputEvent): async def test_audio_sending_flow(self, realtime, mia_audio_16khz): """Test sending real audio data and verify connection remains stable""" events = [] - + @realtime.events.subscribe async def on_audio(event: RealtimeAudioOutputEvent): events.append(event) - + await asyncio.sleep(0.01) await realtime.connect() - - await realtime.simple_response("Listen to the following story, what is Mia looking for?") + + await realtime.simple_response( + "Listen to the following story, what is Mia looking for?" 
+ ) await asyncio.sleep(10.0) await realtime.simple_audio_response(mia_audio_16khz) @@ -62,26 +64,25 @@ async def on_audio(event: RealtimeAudioOutputEvent): await asyncio.sleep(10.0) assert len(events) > 0 - @pytest.mark.integration async def test_video_sending_flow(self, realtime, bunny_video_track): """Test sending real video data and verify connection remains stable""" events = [] - + @realtime.events.subscribe async def on_audio(event: RealtimeAudioOutputEvent): events.append(event) - + await asyncio.sleep(0.01) await realtime.connect() await realtime.simple_response("Describe what you see in this video please") await asyncio.sleep(10.0) # Start video sender with low FPS to avoid overwhelming the connection await realtime._watch_video_track(bunny_video_track) - + # Let it run for a few seconds await asyncio.sleep(10.0) - + # Stop video sender await realtime._stop_watching_video_track() assert len(events) > 0 @@ -91,13 +92,14 @@ async def test_frame_to_png_bytes_with_bunny_video(self, bunny_video_track): # Get a frame from the bunny video track frame = await bunny_video_track.recv() png_bytes = frame_to_png_bytes(frame) - + # Verify we got PNG data assert isinstance(png_bytes, bytes) assert len(png_bytes) > 0 - + # Verify it's actually PNG data (PNG files start with specific bytes) - assert png_bytes.startswith(b'\x89PNG\r\n\x1a\n') - - print(f"Successfully converted bunny video frame to PNG: {len(png_bytes)} bytes") + assert png_bytes.startswith(b"\x89PNG\r\n\x1a\n") + print( + f"Successfully converted bunny video frame to PNG: {len(png_bytes)} bytes" + ) diff --git a/plugins/gemini/tests/test_realtime_function_calling.py b/plugins/gemini/tests/test_realtime_function_calling.py index 24d2848c..54b02ec4 100644 --- a/plugins/gemini/tests/test_realtime_function_calling.py +++ b/plugins/gemini/tests/test_realtime_function_calling.py @@ -9,7 +9,10 @@ from dotenv import load_dotenv from vision_agents.plugins import gemini -from vision_agents.core.llm.events import RealtimeResponseEvent, RealtimeAudioOutputEvent +from vision_agents.core.llm.events import ( + RealtimeResponseEvent, + RealtimeAudioOutputEvent, +) # Load environment variables load_dotenv() @@ -23,7 +26,9 @@ async def realtime_instance(self): """Create a realtime instance with real Gemini client.""" api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") if not api_key: - pytest.skip("GOOGLE_API_KEY or GEMINI_API_KEY not set - skipping integration test") + pytest.skip( + "GOOGLE_API_KEY or GEMINI_API_KEY not set - skipping integration test" + ) # Check if google.genai is available try: @@ -32,10 +37,9 @@ async def realtime_instance(self): pytest.skip(f"Required Google packages not available: {e}") realtime = gemini.Realtime( - model="gemini-2.5-flash-native-audio-preview-09-2025", - api_key=api_key + model="gemini-2.5-flash-native-audio-preview-09-2025", api_key=api_key ) - + try: yield realtime finally: @@ -46,7 +50,7 @@ async def test_convert_tools_to_provider_format(self): """Test tool conversion to Gemini Live format.""" # Create a minimal instance just for testing the conversion method realtime = gemini.Realtime(model="test-model", api_key="test-key") - + # Test tools tools = [ { @@ -57,8 +61,8 @@ async def test_convert_tools_to_provider_format(self): "properties": { "location": {"type": "string", "description": "City name"} }, - "required": ["location"] - } + "required": ["location"], + }, }, { "name": "calculate", @@ -66,25 +70,28 @@ async def test_convert_tools_to_provider_format(self): "parameters_schema": { 
"type": "object", "properties": { - "expression": {"type": "string", "description": "Math expression"} + "expression": { + "type": "string", + "description": "Math expression", + } }, - "required": ["expression"] - } - } + "required": ["expression"], + }, + }, ] - + result = realtime._convert_tools_to_provider_format(tools) - + assert len(result) == 1 assert "function_declarations" in result[0] assert len(result[0]["function_declarations"]) == 2 - + # Check first tool tool1 = result[0]["function_declarations"][0] assert tool1["name"] == "get_weather" assert tool1["description"] == "Get weather information" assert "location" in tool1["parameters"]["properties"] - + # Check second tool tool2 = result[0]["function_declarations"][1] assert tool2["name"] == "calculate" @@ -96,11 +103,11 @@ async def test_convert_tools_to_provider_format(self): async def test_live_function_calling_basic(self, realtime_instance): """Test basic live function calling with weather function.""" realtime = realtime_instance - + # Track function calls and responses function_calls: List[Dict[str, Any]] = [] text_responses: List[str] = [] - + # Register a weather function @realtime.register_function(description="Get current weather for a location") def get_weather(location: str) -> Dict[str, str]: @@ -110,9 +117,9 @@ def get_weather(location: str) -> Dict[str, str]: "location": location, "temperature": "22°C", "condition": "Sunny", - "humidity": "65%" + "humidity": "65%", } - + # Set up event listeners for audio output @realtime.events.subscribe async def handle_audio_output(event: RealtimeAudioOutputEvent): @@ -124,22 +131,22 @@ async def handle_audio_output(event: RealtimeAudioOutputEvent): async def handle_response(event: RealtimeResponseEvent): if event.text: text_responses.append(event.text) - + # Connect and send a prompt that should trigger the function await realtime.connect() - + # Send a prompt that encourages function calling prompt = "What's the weather like in New York? Please use the get_weather function to check." 
await realtime.simple_response(prompt) - + # Wait for response and function call await asyncio.sleep(8.0) - + # Verify function was called assert len(function_calls) > 0, "Function was not called by Gemini" assert function_calls[0]["name"] == "get_weather" assert function_calls[0]["location"] == "New York" - + # Remove the text response assertion @pytest.mark.integration @@ -147,11 +154,11 @@ async def handle_response(event: RealtimeResponseEvent): async def test_live_function_calling_error_handling(self, realtime_instance): """Test live function calling with error handling.""" realtime = realtime_instance - + # Track function calls and responses function_calls: List[Dict[str, Any]] = [] text_responses: List[str] = [] - + # Register a function that will raise an error @realtime.register_function(description="A function that sometimes fails") def unreliable_function(input_data: str) -> Dict[str, Any]: @@ -160,33 +167,33 @@ def unreliable_function(input_data: str) -> Dict[str, Any]: if "error" in input_data.lower(): raise ValueError("Simulated error for testing") return {"result": f"Success: {input_data}"} - + # Set up event listeners for audio output @realtime.events.subscribe async def handle_audio_output(event: RealtimeAudioOutputEvent): if event.audio_data: # Audio was received - this indicates Gemini responded text_responses.append("audio_response_received") - + @realtime.events.subscribe async def handle_response(event: RealtimeResponseEvent): if event.text: text_responses.append(event.text) - + # Connect and send a prompt that should trigger the function with error await realtime.connect() - + # Send a prompt that should cause an error prompt = "Please call the unreliable_function with 'error test' as input." await realtime.simple_response(prompt) - + # Wait for response and function call await asyncio.sleep(8.0) - + # Verify function was called assert len(function_calls) > 0, "Function was not called by Gemini" assert function_calls[0]["name"] == "unreliable_function" - + # Verify we got a response (should mention the error) assert len(text_responses) > 0, "No response received from Gemini" @@ -195,53 +202,55 @@ async def handle_response(event: RealtimeResponseEvent): async def test_live_function_calling_multiple_functions(self, realtime_instance): """Test live function calling with multiple functions in one request.""" realtime = realtime_instance - + # Track function calls function_calls: List[Dict[str, Any]] = [] text_responses: List[str] = [] - + # Register multiple functions @realtime.register_function(description="Get current time") def get_time() -> Dict[str, str]: """Get current time.""" function_calls.append({"name": "get_time"}) return {"time": "2024-01-15 14:30:00", "timezone": "UTC"} - + @realtime.register_function(description="Get system status") def get_status() -> Dict[str, str]: """Get system status.""" function_calls.append({"name": "get_status"}) return {"status": "healthy", "uptime": "24h"} - + # Set up event listeners for audio output @realtime.events.subscribe async def handle_audio_output(event: RealtimeAudioOutputEvent): if event.audio_data: # Audio was received - this indicates Gemini responded text_responses.append("audio_response_received") - + @realtime.events.subscribe async def handle_response(event: RealtimeResponseEvent): if event.text: text_responses.append(event.text) - + # Connect and send a prompt that should trigger multiple functions await realtime.connect() - + # Send a prompt that encourages multiple function calls prompt = "Please check the 
current time and system status using the available functions." await realtime.simple_response(prompt) - + # Wait for response and function calls await asyncio.sleep(10.0) - + # Verify functions were called - assert len(function_calls) >= 2, f"Expected at least 2 function calls, got {len(function_calls)}" - + assert len(function_calls) >= 2, ( + f"Expected at least 2 function calls, got {len(function_calls)}" + ) + function_names = [call["name"] for call in function_calls] assert "get_time" in function_names, "get_time function was not called" assert "get_status" in function_names, "get_status function was not called" - + # Verify we got a response assert len(text_responses) > 0, "No response received from Gemini" @@ -250,14 +259,14 @@ async def test_create_config_with_tools(self): """Test that tools are added to the config.""" # Create a minimal instance for testing config creation realtime = gemini.Realtime(model="test-model", api_key="test-key") - + # Register a test function @realtime.register_function(description="Test function") def test_func(param: str) -> str: return f"test: {param}" - + config = realtime._get_config_with_resumption() - + # Verify tools were added assert "tools" in config assert len(config["tools"]) == 1 @@ -270,8 +279,8 @@ async def test_create_config_without_tools(self): """Test config creation when no tools are available.""" # Create a minimal instance without registering any functions realtime = gemini.Realtime(model="test-model", api_key="test-key") - + config = realtime._create_config() - + # Verify tools were not added assert "tools" not in config diff --git a/plugins/gemini/vision_agents/plugins/gemini/events.py b/plugins/gemini/vision_agents/plugins/gemini/events.py index c05ff05c..66e5050d 100644 --- a/plugins/gemini/vision_agents/plugins/gemini/events.py +++ b/plugins/gemini/vision_agents/plugins/gemini/events.py @@ -6,33 +6,38 @@ @dataclass class GeminiConnectedEvent(PluginBaseEvent): """Event emitted when Gemini realtime connection is established.""" - type: str = field(default='plugin.gemini.connected', init=False) + + type: str = field(default="plugin.gemini.connected", init=False) model: Optional[str] = None @dataclass class GeminiErrorEvent(PluginBaseEvent): """Event emitted when Gemini encounters an error.""" - type: str = field(default='plugin.gemini.error', init=False) + + type: str = field(default="plugin.gemini.error", init=False) error: Optional[Any] = None @dataclass class GeminiAudioEvent(PluginBaseEvent): """Event emitted when Gemini provides audio output.""" - type: str = field(default='plugin.gemini.audio', init=False) + + type: str = field(default="plugin.gemini.audio", init=False) audio_data: Optional[bytes] = None @dataclass class GeminiTextEvent(PluginBaseEvent): """Event emitted when Gemini provides text output.""" - type: str = field(default='plugin.gemini.text', init=False) + + type: str = field(default="plugin.gemini.text", init=False) text: Optional[str] = None @dataclass class GeminiResponseEvent(PluginBaseEvent): """Event emitted when Gemini provides a response chunk.""" - type: str = field(default='plugin.gemini.response', init=False) + + type: str = field(default="plugin.gemini.response", init=False) response_chunk: Optional[Any] = None diff --git a/plugins/gemini/vision_agents/plugins/gemini/gemini_llm.py b/plugins/gemini/vision_agents/plugins/gemini/gemini_llm.py index 2994e3ae..06d8b4ef 100644 --- a/plugins/gemini/vision_agents/plugins/gemini/gemini_llm.py +++ b/plugins/gemini/vision_agents/plugins/gemini/gemini_llm.py @@ 
-7,7 +7,10 @@ from vision_agents.core.llm.llm import LLM, LLMResponseEvent from vision_agents.core.llm.llm_types import ToolSchema, NormalizedToolCallItem -from vision_agents.core.llm.events import LLMResponseCompletedEvent, LLMResponseChunkEvent +from vision_agents.core.llm.events import ( + LLMResponseCompletedEvent, + LLMResponseChunkEvent, +) from . import events @@ -19,24 +22,30 @@ class GeminiLLM(LLM): """ - The GeminiLLM class provides full/native access to the gemini SDK methods. - It only standardized the minimal feature set that's needed for the agent integration. + The GeminiLLM class provides full/native access to the gemini SDK methods. + It only standardizes the minimal feature set that's needed for the agent integration. - The agent requires that we standardize: - - sharing instructions - - keeping conversation history - - response normalization + The agent requires that we standardize: + - sharing instructions + - keeping conversation history + - response normalization - Notes on the Gemini integration: - - the native method is called send_message (maps 1-1 to chat.send_message_stream) - - history is maintained in the gemini sdk (with the usage of client.chats.create(model=self.model)) + Notes on the Gemini integration: + - the native method is called send_message (maps 1-1 to chat.send_message_stream) + - history is maintained in the gemini sdk (with the usage of client.chats.create(model=self.model)) - Examples: + Examples: - from vision_agents.plugins import gemini - llm = gemini.LLM() - """ - def __init__(self, model: str, api_key: Optional[str] = None, client: Optional[genai.Client] = None): + from vision_agents.plugins import gemini + llm = gemini.LLM() + """ + + def __init__( + self, + model: str, + api_key: Optional[str] = None, + client: Optional[genai.Client] = None, + ): """ Initialize the GeminiLLM class. @@ -55,7 +64,12 @@ def __init__(self, model: str, api_key: Optional[str] = None, client: Optional[g else: self.client = genai.Client(api_key=api_key) - async def simple_response(self, text: str, processors: Optional[List[Processor]] = None, participant: Optional[Any] = None) -> LLMResponseEvent[Any]: + async def simple_response( + self, + text: str, + processors: Optional[List[Processor]] = None, + participant: Optional[Any] = None, + ) -> LLMResponseEvent[Any]: """ simple_response is a standardized way (across openai, claude, gemini etc.) to create a response.
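For orientation, a minimal usage sketch of the wrapped API above (not part of the diff; it assumes a Google API key is configured in the environment, and the model name is illustrative):

```python
import asyncio

from vision_agents.plugins import gemini


async def main() -> None:
    # Model name is an assumption; pass whichever Gemini model you use.
    llm = gemini.LLM(model="gemini-2.0-flash")

    # History is kept inside the SDK chat object (client.chats.create),
    # so a follow-up simple_response() shares the same conversation.
    first = await llm.simple_response("My name is Ada.")
    second = await llm.simple_response("What is my name?")
    print(first.text)
    print(second.text)


asyncio.run(main())
```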
@@ -67,9 +81,7 @@ async def simple_response(self, text: str, processors: Optional[List[Processor]] llm.simple_response("say hi to the user, be mean") """ - return await self.send_message( - message=text - ) + return await self.send_message(message=text) async def send_message(self, *args, **kwargs): """ @@ -77,7 +89,7 @@ async def send_message(self, *args, **kwargs): under the hood it calls chat.send_message_stream(*args, **kwargs) this method wraps and ensures we broadcast an event which the agent class hooks into """ - #if "model" not in kwargs: + # if "model" not in kwargs: # kwargs["model"] = self.model # initialize chat if needed @@ -88,6 +100,7 @@ async def send_message(self, *args, **kwargs): tools_spec = self.get_available_functions() if tools_spec: from google.genai import types + conv_tools = self._convert_tools_to_provider_format(tools_spec) cfg = kwargs.get("config") if not isinstance(cfg, types.GenerateContentConfig): @@ -97,14 +110,16 @@ async def send_message(self, *args, **kwargs): # Generate content using the client iterator = self.chat.send_message_stream(*args, **kwargs) - text_parts : List[str] = [] + text_parts: List[str] = [] final_chunk = None pending_calls: List[NormalizedToolCallItem] = [] - + for chunk in iterator: response_chunk: GenerateContentResponse = chunk final_chunk = response_chunk - llm_response_optional = self._standardize_and_emit_event(response_chunk, text_parts) + llm_response_optional = self._standardize_and_emit_event( + response_chunk, text_parts + ) # collect function calls as they stream try: chunk_calls = self._extract_tool_calls_from_stream_chunk(chunk) @@ -121,12 +136,14 @@ async def send_message(self, *args, **kwargs): rounds = 0 current_calls = pending_calls cfg_with_tools = kwargs.get("config") - + seen: set[str] = set() while current_calls and rounds < MAX_ROUNDS: # Execute tools concurrently with deduplication - triples, seen = await self._dedup_and_execute(current_calls, max_concurrency=8, timeout_s=30, seen=seen) # type: ignore[arg-type] - + triples, seen = await self._dedup_and_execute( + current_calls, max_concurrency=8, timeout_s=30, seen=seen + ) # type: ignore[arg-type] + executed = [] parts = [] for tc, res, err in triples: @@ -138,28 +155,36 @@ async def send_message(self, *args, **kwargs): sanitized_res = {} for k, v in res.items(): sanitized_res[k] = self._sanitize_tool_output(v) - parts.append(types.Part.from_function_response(name=tc["name"], response=sanitized_res)) - + parts.append( + types.Part.from_function_response( + name=tc["name"], response=sanitized_res + ) + ) + # Send function responses with tools config - follow_up_iter = self.chat.send_message_stream(parts, config=cfg_with_tools) # type: ignore[arg-type] - + follow_up_iter = self.chat.send_message_stream( + parts, config=cfg_with_tools + ) # type: ignore[arg-type] + follow_up_text_parts: List[str] = [] follow_up_last = None next_calls = [] - + for chk in follow_up_iter: follow_up_last = chk - llm_response_optional = self._standardize_and_emit_event(chk, follow_up_text_parts) + llm_response_optional = self._standardize_and_emit_event( + chk, follow_up_text_parts + ) if llm_response_optional is not None: llm_response = llm_response_optional - + # Check for new function calls try: chunk_calls = self._extract_tool_calls_from_stream_chunk(chk) next_calls.extend(chunk_calls) except Exception: pass - + current_calls = next_calls rounds += 1 @@ -169,11 +194,13 @@ async def send_message(self, *args, **kwargs): total_text = "".join(text_parts) llm_response = 
LLMResponseEvent(final_chunk, total_text) - self.events.send(LLMResponseCompletedEvent( - plugin_name="gemini", - original=llm_response.original, - text=llm_response.text - )) + self.events.send( + LLMResponseCompletedEvent( + plugin_name="gemini", + original=llm_response.original, + text=llm_response.text, + ) + ) # Return the LLM response return llm_response @@ -181,12 +208,10 @@ async def send_message(self, *args, **kwargs): @staticmethod def _normalize_message(gemini_input) -> List["Message"]: from vision_agents.core.agents.conversation import Message - + # standardize on input if isinstance(gemini_input, str): - gemini_input = [ - gemini_input - ] + gemini_input = [gemini_input] if not isinstance(gemini_input, List): gemini_input = [gemini_input] @@ -198,31 +223,36 @@ def _normalize_message(gemini_input) -> List["Message"]: return messages - def _standardize_and_emit_event(self, chunk: GenerateContentResponse, text_parts: List[str]) -> Optional[LLMResponseEvent[Any]]: + def _standardize_and_emit_event( + self, chunk: GenerateContentResponse, text_parts: List[str] + ) -> Optional[LLMResponseEvent[Any]]: """ Forwards the events and also send out a standardized version (the agent class hooks into that) """ # forward the native event - self.events.send(events.GeminiResponseEvent( - plugin_name="gemini", - response_chunk=chunk - )) + self.events.send( + events.GeminiResponseEvent(plugin_name="gemini", response_chunk=chunk) + ) # Check if response has text content - if hasattr(chunk, 'text') and chunk.text: - self.events.send(LLMResponseChunkEvent( - plugin_name="gemini", - content_index=0, - item_id="", - output_index=0, - sequence_number=0, - delta=chunk.text, - )) + if hasattr(chunk, "text") and chunk.text: + self.events.send( + LLMResponseChunkEvent( + plugin_name="gemini", + content_index=0, + item_id="", + output_index=0, + sequence_number=0, + delta=chunk.text, + ) + ) text_parts.append(chunk.text) return None - def _convert_tools_to_provider_format(self, tools: List[ToolSchema]) -> List[Dict[str, Any]]: + def _convert_tools_to_provider_format( + self, tools: List[ToolSchema] + ) -> List[Dict[str, Any]]: """ Convert ToolSchema objects to Gemini format. Args: @@ -232,75 +262,93 @@ def _convert_tools_to_provider_format(self, tools: List[ToolSchema]) -> List[Dic """ function_declarations = [] for tool in tools: - function_declarations.append({ - "name": tool["name"], - "description": tool.get("description", ""), - "parameters": tool["parameters_schema"] - }) - + function_declarations.append( + { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": tool["parameters_schema"], + } + ) + # Return as dict with function_declarations (SDK accepts dicts) return [{"function_declarations": function_declarations}] - def _extract_tool_calls_from_response(self, response: Any) -> List[NormalizedToolCallItem]: + def _extract_tool_calls_from_response( + self, response: Any + ) -> List[NormalizedToolCallItem]: """ Extract tool calls from Gemini response. 
- + Args: response: Gemini response object - + Returns: List of normalized tool call items """ calls: List[NormalizedToolCallItem] = [] - + try: # Prefer the top-level convenience list if available function_calls = getattr(response, "function_calls", []) or [] for fc in function_calls: - calls.append({ - "type": "tool_call", - "name": getattr(fc, "name", "unknown"), - "arguments_json": getattr(fc, "args", {}) - }) + calls.append( + { + "type": "tool_call", + "name": getattr(fc, "name", "unknown"), + "arguments_json": getattr(fc, "args", {}), + } + ) if not calls and getattr(response, "candidates", None): for c in response.candidates: if getattr(c, "content", None): for part in c.content.parts: if getattr(part, "function_call", None): - calls.append({ - "type": "tool_call", - "name": getattr(part.function_call, "name", "unknown"), - "arguments_json": getattr(part.function_call, "args", {}), - }) + calls.append( + { + "type": "tool_call", + "name": getattr( + part.function_call, "name", "unknown" + ), + "arguments_json": getattr( + part.function_call, "args", {} + ), + } + ) except Exception: pass # Ignore extraction errors - + return calls - def _extract_tool_calls_from_stream_chunk(self, chunk: Any) -> List[NormalizedToolCallItem]: + def _extract_tool_calls_from_stream_chunk( + self, chunk: Any + ) -> List[NormalizedToolCallItem]: """ Extract tool calls from Gemini streaming chunk. - + Args: chunk: Gemini streaming event - + Returns: List of normalized tool call items """ try: - return self._extract_tool_calls_from_response(chunk) # chunks use same shape + return self._extract_tool_calls_from_response( + chunk + ) # chunks use same shape except Exception: return [] # Ignore extraction errors - def _create_tool_result_parts(self, tool_calls: List[NormalizedToolCallItem], results: List[Any]): + def _create_tool_result_parts( + self, tool_calls: List[NormalizedToolCallItem], results: List[Any] + ): """ Create function_response parts for Gemini. 
- + Args: tool_calls: List of tool calls that were executed results: List of results from function execution - + Returns: List of function_response parts """ @@ -312,9 +360,13 @@ def _create_tool_result_parts(self, tool_calls: List[NormalizedToolCallItem], re response_data = res else: response_data = {"result": res} - + # res may be dict/list/str; pass directly; SDK serializes - parts.append(types.Part.from_function_response(name=tc["name"], response=response_data)) + parts.append( + types.Part.from_function_response( + name=tc["name"], response=response_data + ) + ) except Exception: # Fallback: create a simple text part parts.append(types.Part(text=f"Function {tc['name']} returned: {res}")) diff --git a/plugins/gemini/vision_agents/plugins/gemini/gemini_realtime.py b/plugins/gemini/vision_agents/plugins/gemini/gemini_realtime.py index 651a06c4..35a68f2d 100644 --- a/plugins/gemini/vision_agents/plugins/gemini/gemini_realtime.py +++ b/plugins/gemini/vision_agents/plugins/gemini/gemini_realtime.py @@ -6,13 +6,29 @@ from google import genai from google.genai.live import AsyncSession from google.genai.types import SessionResumptionConfigDict -from google.genai.types import LiveConnectConfigDict, Modality, SpeechConfigDict, VoiceConfigDict, \ - PrebuiltVoiceConfigDict, AudioTranscriptionConfigDict, RealtimeInputConfigDict, TurnCoverage, \ - ContextWindowCompressionConfigDict, SlidingWindowDict, HttpOptions, LiveServerMessage, Blob, Part +from google.genai.types import ( + LiveConnectConfigDict, + Modality, + SpeechConfigDict, + VoiceConfigDict, + PrebuiltVoiceConfigDict, + AudioTranscriptionConfigDict, + RealtimeInputConfigDict, + TurnCoverage, + ContextWindowCompressionConfigDict, + SlidingWindowDict, + HttpOptions, + LiveServerMessage, + Blob, + Part, +) from vision_agents.core.edge.types import Participant from vision_agents.core.llm import realtime -from vision_agents.core.llm.events import RealtimeAudioOutputEvent, LLMResponseChunkEvent +from vision_agents.core.llm.events import ( + RealtimeAudioOutputEvent, + LLMResponseChunkEvent, +) from vision_agents.core.llm.llm_types import ToolSchema, NormalizedToolCallItem from vision_agents.core.processors import Processor from vision_agents.core.utils.utils import frame_to_png_bytes @@ -56,18 +72,27 @@ class Realtime(realtime.Realtime): - Audio output always uses a sample rate of 24kHz. 
- Input audio is natively 16kHz, but the Live API will resample if needed """ - model : str + + model: str session_resumption_id: Optional[str] = None config: LiveConnectConfigDict - connected : bool = False - - def __init__(self, model: str=DEFAULT_MODEL, config: Optional[LiveConnectConfigDict]=None, http_options: Optional[HttpOptions] = None, client: Optional[genai.Client] = None, api_key: Optional[str] = None , **kwargs) -> None: + connected: bool = False + + def __init__( + self, + model: str = DEFAULT_MODEL, + config: Optional[LiveConnectConfigDict] = None, + http_options: Optional[HttpOptions] = None, + client: Optional[genai.Client] = None, + api_key: Optional[str] = None, + **kwargs, + ) -> None: super().__init__(**kwargs) self.model = model if http_options is None: http_options = HttpOptions(api_version="v1alpha") - if client is None: + if client is None: if api_key: client = genai.Client(api_key=api_key, http_options=http_options) else: @@ -85,8 +110,12 @@ def __init__(self, model: str=DEFAULT_MODEL, config: Optional[LiveConnectConfigD self._session: Optional[AsyncSession] = None self._receive_task: Optional[asyncio.Task[Any]] = None - async def simple_response(self, text: str, processors: Optional[List[Processor]] = None, - participant: Optional[Participant] = None): + async def simple_response( + self, + text: str, + processors: Optional[List[Processor]] = None, + participant: Optional[Participant] = None, + ): """ Simple response standardizes how to send a text instruction to this LLM. @@ -98,7 +127,6 @@ async def simple_response(self, text: str, processors: Optional[List[Processor]] self.logger.info("Simple response called with text: %s", text) await self.send_realtime_input(text=text) - async def simple_audio_response(self, pcm: PcmData): """ Simple audio response standardizes how to send audio to the LLM @@ -125,9 +153,7 @@ async def send_realtime_input(self, *args, **kwargs): send_realtime_input wraps the native send_realtime_input """ try: - await self._require_session().send_realtime_input( - *args, **kwargs - ) + await self._require_session().send_realtime_input(*args, **kwargs) except Exception as e: # reconnect here in some cases self.logger.error(e) @@ -141,16 +167,16 @@ async def send_client_content(self, *args, **kwargs): """ Don't use send client content, it can cause bugs when combined with send_realtime_input """ - await self._require_session().send_client_content( - *args, **kwargs - ) + await self._require_session().send_client_content(*args, **kwargs) async def connect(self): """ Connect to Gemini's websocket """ self.logger.info("Connecting to gemini live, config set to %s", self.config) - self._session_context = self.client.aio.live.connect(model=self.model, config=self._get_config_with_resumption()) + self._session_context = self.client.aio.live.connect( + model=self.model, config=self._get_config_with_resumption() + ) self._session = await self._session_context.__aenter__() self.connected = True self.logger.info("Gemini live connected to session %s", self._session) @@ -158,7 +184,6 @@ async def connect(self): # Start the receive loop task self._receive_task = asyncio.create_task(self._receive_loop()) - async def _reconnect(self): await self.connect() @@ -173,24 +198,66 @@ async def _receive_loop(self): async for response in self._require_session().receive(): server_message: LiveServerMessage = response - is_input_transcript = server_message and server_message.server_content and server_message.server_content.input_transcription - is_output_transcript = 
server_message and server_message.server_content and server_message.server_content.output_transcription - is_response = server_message and server_message.server_content and server_message.server_content.model_turn - is_interrupt = server_message and server_message.server_content and server_message.server_content.interrupted - is_turn_complete = server_message and server_message.server_content and server_message.server_content.turn_complete - is_generation_complete = server_message and server_message.server_content and server_message.server_content.generation_complete + is_input_transcript = ( + server_message + and server_message.server_content + and server_message.server_content.input_transcription + ) + is_output_transcript = ( + server_message + and server_message.server_content + and server_message.server_content.output_transcription + ) + is_response = ( + server_message + and server_message.server_content + and server_message.server_content.model_turn + ) + is_interrupt = ( + server_message + and server_message.server_content + and server_message.server_content.interrupted + ) + is_turn_complete = ( + server_message + and server_message.server_content + and server_message.server_content.turn_complete + ) + is_generation_complete = ( + server_message + and server_message.server_content + and server_message.server_content.generation_complete + ) if is_input_transcript: # TODO: what to do with this? check with Tommaso - if server_message.server_content and server_message.server_content.input_transcription: - self.logger.info("input: %s", server_message.server_content.input_transcription.text) + if ( + server_message.server_content + and server_message.server_content.input_transcription + ): + self.logger.info( + "input: %s", + server_message.server_content.input_transcription.text, + ) elif is_output_transcript: # TODO: what to do with this? 
- if server_message.server_content and server_message.server_content.output_transcription: - self.logger.info("output: %s", server_message.server_content.output_transcription.text) + if ( + server_message.server_content + and server_message.server_content.output_transcription + ): + self.logger.info( + "output: %s", + server_message.server_content.output_transcription.text, + ) elif is_interrupt: - if server_message.server_content and server_message.server_content.interrupted: - self.logger.info("interrupted: %s", server_message.server_content.interrupted) + if ( + server_message.server_content + and server_message.server_content.interrupted + ): + self.logger.info( + "interrupted: %s", + server_message.server_content.interrupted, + ) elif is_response: # Store the resumption id so we can resume a broken connection if server_message.session_resumption_update: @@ -198,7 +265,10 @@ async def _receive_loop(self): if update.resumable and update.new_handle: self.session_resumption_id = update.new_handle - if server_message.server_content and server_message.server_content.model_turn: + if ( + server_message.server_content + and server_message.server_content.model_turn + ): parts = server_message.server_content.model_turn.parts if parts: @@ -206,41 +276,60 @@ async def _receive_loop(self): typed_part: Part = current_part if typed_part.text: if typed_part.thought: - self.logger.info("Gemini thought %s", typed_part.text) + self.logger.info( + "Gemini thought %s", typed_part.text + ) else: - self.logger.info("output: %s", typed_part.text) + self.logger.info( + "output: %s", typed_part.text + ) event = LLMResponseChunkEvent( delta=typed_part.text ) self.events.send(event) elif typed_part.inline_data: data = typed_part.inline_data.data - + # Emit audio output event audio_event = RealtimeAudioOutputEvent( plugin_name="gemini", audio_data=data, - sample_rate=24000 + sample_rate=24000, ) self.events.send(audio_event) - - await self.output_track.write(data) # original 24khz here - elif hasattr(typed_part, 'function_call') and typed_part.function_call: + + await self.output_track.write( + data + ) # original 24khz here + elif ( + hasattr(typed_part, "function_call") + and typed_part.function_call + ): # Handle function calls from Gemini Live - self.logger.info(f"Received function call: {typed_part.function_call.name}") - await self._handle_function_call(typed_part.function_call) + self.logger.info( + f"Received function call: {typed_part.function_call.name}" + ) + await self._handle_function_call( + typed_part.function_call + ) else: - self.logger.debug("Unrecognized part type: %s", typed_part) + self.logger.debug( + "Unrecognized part type: %s", typed_part + ) elif is_turn_complete: self.logger.info("is_turn_complete complete") elif is_generation_complete: self.logger.info("is_generation_complete complete") elif server_message.tool_call: # Handle tool calls from Gemini Live - self.logger.info(f"Received tool call: {server_message.tool_call}") + self.logger.info( + f"Received tool call: {server_message.tool_call}" + ) await self._handle_tool_call(server_message.tool_call) else: - self.logger.warning("Unrecognized event structure for gemini %s", server_message) + self.logger.warning( + "Unrecognized event structure for gemini %s", server_message + ) except Exception as e: # reconnect here for some errors self.logger.error(f"_receive_loop error: {e}") @@ -264,11 +353,11 @@ def _is_temporary_error(e: Exception): async def _close_impl(self): self.connected = False - if hasattr(self, '_receive_task') and 
self._receive_task: + if hasattr(self, "_receive_task") and self._receive_task: self._receive_task.cancel() await self._receive_task - if hasattr(self, '_session_context') and self._session_context: + if hasattr(self, "_session_context") and self._session_context: # Properly close the session using the context manager's __aexit__ try: await self._session_context.__aexit__(None, None, None) @@ -277,30 +366,29 @@ async def _close_impl(self): self._session_context = None self._session = None - async def _watch_video_track(self, track: Any, **kwargs) -> None: """ Start sending video frames to Gemini using VideoForwarder. We follow the on_track from Stream. If video is turned on or off this gets forwarded. - + Args: track: Video track to watch shared_forwarder: Optional shared VideoForwarder to use instead of creating a new one """ - shared_forwarder = kwargs.get('shared_forwarder') - + shared_forwarder = kwargs.get("shared_forwarder") + if self._video_forwarder is not None and shared_forwarder is None: self.logger.warning("Video sender already running, stopping previous one") await self._stop_watching_video_track() - + if shared_forwarder is not None: # Use the shared forwarder - just register as a consumer self._video_forwarder = shared_forwarder - self.logger.info(f"🎥 Gemini subscribing to shared VideoForwarder at {self.fps} FPS") + self.logger.info( + f"🎥 Gemini subscribing to shared VideoForwarder at {self.fps} FPS" + ) await self._video_forwarder.start_event_consumer( - self._send_video_frame, - fps=float(self.fps), - consumer_name="gemini" + self._send_video_frame, fps=float(self.fps), consumer_name="gemini" ) else: # Create our own VideoForwarder with the input track (legacy behavior) @@ -310,13 +398,13 @@ async def _watch_video_track(self, track: Any, **kwargs) -> None: fps=float(self.fps), name="gemini_forwarder", ) - + # Start the forwarder await self._video_forwarder.start() - + # Start the callback consumer that sends frames to Gemini await self._video_forwarder.start_event_consumer(self._send_video_frame) - + self.logger.info(f"Started video forwarding with {self.fps} FPS") async def _stop_watching_video_track(self) -> None: @@ -339,7 +427,9 @@ async def _send_video_frame(self, frame: av.VideoFrame) -> None: except Exception as e: self.logger.error(f"Error sending video frame: {e}") - def _create_config(self, config: Optional[LiveConnectConfigDict]=None) -> LiveConnectConfigDict: + def _create_config( + self, config: Optional[LiveConnectConfigDict] = None + ) -> LiveConnectConfigDict: """ _create_config combines the default config with your settings """ @@ -362,10 +452,10 @@ def _create_config(self, config: Optional[LiveConnectConfigDict]=None) -> LiveCo sliding_window=SlidingWindowDict(target_tokens=12800), ), ) - + # Note: Tools will be added later in _get_config_with_resumption() # when functions are actually registered - + if config is not None: for k, v in config.items(): if k in default_config: @@ -379,12 +469,14 @@ def _get_config_with_resumption(self) -> LiveConnectConfigDict: config = self.config.copy() # resume if we have a session resumption id/handle if self.session_resumption_id: - resumption_config: SessionResumptionConfigDict = {"handle": self.session_resumption_id} # type: ignore[typeddict-item] + resumption_config: SessionResumptionConfigDict = { + "handle": self.session_resumption_id + } # type: ignore[typeddict-item] config["session_resumption"] = resumption_config # type: ignore[typeddict-item] # set the instructions # TODO: potentially we can share the 
markdown as files/parts.. might do better TBD config["system_instruction"] = self._build_enhanced_instructions() - + # Add tools if available - Gemini Live uses similar format to regular Gemini tools_spec = self.get_available_functions() if tools_spec: @@ -395,58 +487,69 @@ def _get_config_with_resumption(self) -> LiveConnectConfigDict: self.logger.info(f"Added {len(tools_spec)} tools to Gemini Live config") else: self.logger.debug("No tools available - function calling will not work") - + return config - def _convert_tools_to_provider_format(self, tools: List[ToolSchema]) -> List[Dict[str, Any]]: + def _convert_tools_to_provider_format( + self, tools: List[ToolSchema] + ) -> List[Dict[str, Any]]: """ Convert ToolSchema objects to Gemini Live format. - + Args: tools: List of ToolSchema objects - + Returns: List of tools in Gemini Live format """ function_declarations = [] for tool in tools: - function_declarations.append({ - "name": tool["name"], - "description": tool.get("description", ""), - "parameters": tool["parameters_schema"] - }) - + function_declarations.append( + { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": tool["parameters_schema"], + } + ) + # Return as dict with function_declarations (similar to regular Gemini format) return [{"function_declarations": function_declarations}] - def _extract_tool_calls_from_response(self, response: Any) -> List[NormalizedToolCallItem]: + def _extract_tool_calls_from_response( + self, response: Any + ) -> List[NormalizedToolCallItem]: """ Extract tool calls from Gemini Live response. - + Args: response: Gemini Live response object - + Returns: List of normalized tool call items """ calls: List[NormalizedToolCallItem] = [] - + try: # Check for function calls in the response - if hasattr(response, 'server_content') and response.server_content: - if hasattr(response.server_content, 'model_turn') and response.server_content.model_turn: + if hasattr(response, "server_content") and response.server_content: + if ( + hasattr(response.server_content, "model_turn") + and response.server_content.model_turn + ): parts = response.server_content.model_turn.parts for part in parts: - if hasattr(part, 'function_call') and part.function_call: + if hasattr(part, "function_call") and part.function_call: call_item: NormalizedToolCallItem = { "type": "tool_call", "name": getattr(part.function_call, "name", "unknown"), - "arguments_json": getattr(part.function_call, "args", {}) + "arguments_json": getattr( + part.function_call, "args", {} + ), } calls.append(call_item) except Exception as e: self.logger.debug(f"Error extracting tool calls from response: {e}") - + return calls async def _handle_tool_call(self, tool_call: Any) -> None: @@ -454,7 +557,7 @@ async def _handle_tool_call(self, tool_call: Any) -> None: Handle tool calls from Gemini Live. """ try: - if hasattr(tool_call, 'function_calls') and tool_call.function_calls: + if hasattr(tool_call, "function_calls") and tool_call.function_calls: for function_call in tool_call.function_calls: await self._handle_function_call(function_call) except Exception as e: @@ -463,7 +566,7 @@ async def _handle_tool_call(self, tool_call: Any) -> None: async def _handle_function_call(self, function_call: Any) -> None: """ Handle function calls from Gemini Live responses. 
- + Args: function_call: Function call object from Gemini Live """ @@ -472,14 +575,16 @@ async def _handle_function_call(self, function_call: Any) -> None: tool_call = { "name": getattr(function_call, "name", "unknown"), "arguments_json": getattr(function_call, "args", {}), - "id": getattr(function_call, "id", None) + "id": getattr(function_call, "id", None), } - - self.logger.info(f"Executing function call: {tool_call['name']} with args: {tool_call['arguments_json']}") - + + self.logger.info( + f"Executing function call: {tool_call['name']} with args: {tool_call['arguments_json']}" + ) + # Execute using existing tool execution infrastructure tc, result, error = await self._run_one_tool(tool_call, timeout_s=30) - + # Prepare response data if error: response_data = {"error": str(error)} @@ -490,29 +595,36 @@ async def _handle_function_call(self, function_call: Any) -> None: response_data = {"result": result} else: response_data = result - self.logger.info(f"Function call {tool_call['name']} succeeded: {response_data}") - + self.logger.info( + f"Function call {tool_call['name']} succeeded: {response_data}" + ) + # Send function response back to Gemini Live session call_id_val = tool_call.get("id") await self._send_function_response( - str(tool_call["name"]), - response_data, - str(call_id_val) if call_id_val else None + str(tool_call["name"]), + response_data, + str(call_id_val) if call_id_val else None, ) - + except Exception as e: self.logger.error(f"Error handling function call: {e}") # Send error response back await self._send_function_response( - getattr(function_call, "name", "unknown"), - {"error": str(e)}, - getattr(function_call, "id", None) + getattr(function_call, "name", "unknown"), + {"error": str(e)}, + getattr(function_call, "id", None), ) - async def _send_function_response(self, function_name: str, response_data: Dict[str, Any], call_id: Optional[str] = None) -> None: + async def _send_function_response( + self, + function_name: str, + response_data: Dict[str, Any], + call_id: Optional[str] = None, + ) -> None: """ Send function response back to Gemini Live session. 
- + Args: function_name: Name of the function that was called response_data: Response data to send back @@ -521,20 +633,26 @@ async def _send_function_response(self, function_name: str, response_data: Dict[ try: # Create function response part from google.genai import types - + function_response = types.FunctionResponse( id=call_id, # Use the call_id if provided name=function_name, - response=response_data + response=response_data, ) - + # Send the function response using the correct method # The Gemini Live API uses send_tool_response for function responses - await self._require_session().send_tool_response(function_responses=[function_response]) - self.logger.debug(f"Sent function response for {function_name}: {response_data}") - + await self._require_session().send_tool_response( + function_responses=[function_response] + ) + self.logger.debug( + f"Sent function response for {function_name}: {response_data}" + ) + except Exception as e: - self.logger.error(f"Error sending function response for {function_name}: {e}") + self.logger.error( + f"Error sending function response for {function_name}: {e}" + ) def _require_session(self) -> AsyncSession: if not self._session: diff --git a/plugins/getstream/vision_agents/plugins/getstream/__init__.py b/plugins/getstream/vision_agents/plugins/getstream/__init__.py index 1ab1ff62..27185bbb 100644 --- a/plugins/getstream/vision_agents/plugins/getstream/__init__.py +++ b/plugins/getstream/vision_agents/plugins/getstream/__init__.py @@ -6,4 +6,3 @@ from getstream import Stream as Client __all__ = ["Conversation", "Edge", "Client"] - diff --git a/plugins/getstream/vision_agents/plugins/getstream/stream_conversation.py b/plugins/getstream/vision_agents/plugins/getstream/stream_conversation.py index ce6a1c8b..949c0371 100644 --- a/plugins/getstream/vision_agents/plugins/getstream/stream_conversation.py +++ b/plugins/getstream/vision_agents/plugins/getstream/stream_conversation.py @@ -17,6 +17,7 @@ class StreamConversation(InMemoryConversation): """ Persists the message history to a stream channel & messages """ + messages: List[Message] # maps internal ids to stream message ids @@ -25,7 +26,13 @@ class StreamConversation(InMemoryConversation): channel: ChannelResponse chat_client: ChatClient - def __init__(self, instructions: str, messages: List[Message], channel: ChannelResponse, chat_client: ChatClient): + def __init__( + self, + instructions: str, + messages: List[Message], + channel: ChannelResponse, + chat_client: ChatClient, + ): super().__init__(instructions, messages) self.messages = messages self.channel = channel @@ -35,7 +42,9 @@ def __init__(self, instructions: str, messages: List[Message], channel: ChannelR # Initialize the worker thread for API calls self._api_queue: queue.Queue = queue.Queue() self._shutdown = False - self._worker_thread = threading.Thread(target=self._api_worker, daemon=True, name="StreamConversation-APIWorker") + self._worker_thread = threading.Thread( + target=self._api_worker, daemon=True, name="StreamConversation-APIWorker" + ) self._worker_thread.start() self._pending_operations = 0 self._operations_lock = threading.Lock() @@ -57,30 +66,34 @@ def _api_worker(self): response = self.chat_client.send_message( operation["channel_type"], operation["channel_id"], - operation["request"] + operation["request"], ) # Store the mapping - self.internal_ids_to_stream_ids[operation["internal_id"]] = response.data.message.id + self.internal_ids_to_stream_ids[operation["internal_id"]] = ( + response.data.message.id + ) 
operation["stream_id"] = response.data.message.id elif op_type == "update_message_partial": self.chat_client.update_message_partial( operation["stream_id"], user_id=operation["user_id"], - set=operation["set_data"] + set=operation["set_data"], ) elif op_type == "ephemeral_message_update": self.chat_client.ephemeral_message_update( operation["stream_id"], user_id=operation["user_id"], - set=operation["set_data"] + set=operation["set_data"], ) logger.debug(f"Successfully processed API operation: {op_type}") except Exception as e: - logger.error(f"Error processing API operation {operation.get('type', 'unknown')}: {e}") + logger.error( + f"Error processing API operation {operation.get('type', 'unknown')}: {e}" + ) # Continue processing other operations even if one fails finally: @@ -165,13 +178,20 @@ def queue_update_operation(): max_wait = 5.0 start_time = time.time() while time.time() - start_time < max_wait: - stream_id = self.internal_ids_to_stream_ids.get(message.id if message.id else "") + stream_id = self.internal_ids_to_stream_ids.get( + message.id if message.id else "" + ) if stream_id: update_op = { - "type": "update_message_partial" if completed else "ephemeral_message_update", + "type": "update_message_partial" + if completed + else "ephemeral_message_update", "stream_id": stream_id, "user_id": message.user_id, - "set_data": {"text": message.content, "generating": not completed}, + "set_data": { + "text": message.content, + "generating": not completed, + }, } with self._operations_lock: self._pending_operations += 1 @@ -183,7 +203,14 @@ def queue_update_operation(): # Queue the update in a separate thread to avoid blocking threading.Thread(target=queue_update_operation, daemon=True).start() - def update_message(self, message_id: str, input_text: str, user_id: str, replace_content: bool, completed: bool): + def update_message( + self, + message_id: str, + input_text: str, + user_id: str, + replace_content: bool, + completed: bool, + ): """Update a message in the Stream conversation. This method updates both the local message content and queues the Stream API sync. 
@@ -201,7 +228,9 @@ def update_message(self, message_id: str, input_text: str, user_id: str, replace None (operations are processed asynchronously) """ # First, update the local message using the superclass logic - super().update_message(message_id, input_text, user_id, replace_content, completed) + super().update_message( + message_id, input_text, user_id, replace_content, completed + ) # Get the updated message for Stream API sync message = self.lookup(message_id) @@ -212,12 +241,16 @@ def update_message(self, message_id: str, input_text: str, user_id: str, replace stream_id = self.internal_ids_to_stream_ids.get(message_id) if stream_id is None: - logger.warning(f"stream_id for message {message_id} not found, skipping Stream API update") + logger.warning( + f"stream_id for message {message_id} not found, skipping Stream API update" + ) return None # Queue the update operation update_op = { - "type": "update_message_partial" if completed else "ephemeral_message_update", + "type": "update_message_partial" + if completed + else "ephemeral_message_update", "stream_id": stream_id, "user_id": message.user_id, "set_data": {"text": message.content, "generating": not completed}, @@ -233,4 +266,4 @@ def __del__(self): try: self.shutdown() except Exception as e: - logger.error(f"Error during StreamConversation cleanup: {e}") \ No newline at end of file + logger.error(f"Error during StreamConversation cleanup: {e}") diff --git a/plugins/krisp/vision_agents/plugins/krisp/__init__.py b/plugins/krisp/vision_agents/plugins/krisp/__init__.py index 2b483a95..1ea797b1 100644 --- a/plugins/krisp/vision_agents/plugins/krisp/__init__.py +++ b/plugins/krisp/vision_agents/plugins/krisp/__init__.py @@ -3,4 +3,3 @@ from .turn_detection import TurnDetection __all__ = ["TurnDetection"] - diff --git a/plugins/krisp/vision_agents/plugins/krisp/turn_detection.py b/plugins/krisp/vision_agents/plugins/krisp/turn_detection.py index 0e3f07fc..c2ffed42 100644 --- a/plugins/krisp/vision_agents/plugins/krisp/turn_detection.py +++ b/plugins/krisp/vision_agents/plugins/krisp/turn_detection.py @@ -48,7 +48,7 @@ def __init__( ): super().__init__( confidence_threshold=confidence_threshold, - provider_name="KrispTurnDetection" + provider_name="KrispTurnDetection", ) self.logger = logging.getLogger("KrispTurnDetection") self.model_path = model_path diff --git a/plugins/moonshine/vision_agents/plugins/moonshine/stt.py b/plugins/moonshine/vision_agents/plugins/moonshine/stt.py index afe0b25c..bd1220ac 100644 --- a/plugins/moonshine/vision_agents/plugins/moonshine/stt.py +++ b/plugins/moonshine/vision_agents/plugins/moonshine/stt.py @@ -263,7 +263,9 @@ async def _transcribe_audio( return None async def _process_audio_impl( - self, pcm_data: PcmData, user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None + self, + pcm_data: PcmData, + user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None, ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]: """ Process audio data through Moonshine for transcription. diff --git a/plugins/openai/tests/test_openai_llm.py b/plugins/openai/tests/test_openai_llm.py index fb804b62..5fb4475c 100644 --- a/plugins/openai/tests/test_openai_llm.py +++ b/plugins/openai/tests/test_openai_llm.py @@ -48,31 +48,27 @@ async def test_simple(self, llm: OpenAILLM): @pytest.mark.integration async def test_native_api(self, llm: OpenAILLM): - - response = await llm.create_response( input="say hi", instructions="You are a helpful assistant." 
) # Assertions assert response.text - assert hasattr(response.original, 'id') # OpenAI response has id - + assert hasattr(response.original, "id") # OpenAI response has id @pytest.mark.integration async def test_streaming(self, llm: OpenAILLM): - streamingWorks = False - + @llm.events.subscribe async def passed(event: LLMResponseChunkEvent): nonlocal streamingWorks streamingWorks = True - + response = await llm.simple_response( "Explain quantum computing in 1 paragraph", ) - + await llm.events.wait() assert response.text diff --git a/plugins/openai/tests/test_openai_realtime.py b/plugins/openai/tests/test_openai_realtime.py index cfd9ca6e..c1991032 100644 --- a/plugins/openai/tests/test_openai_realtime.py +++ b/plugins/openai/tests/test_openai_realtime.py @@ -29,11 +29,11 @@ async def test_simple_response_flow(self, realtime): """Test sending a simple text message and receiving response""" # Send a simple message events = [] - + @realtime.events.subscribe async def on_audio(event: RealtimeAudioOutputEvent): events.append(event) - + await asyncio.sleep(0.01) await realtime.connect() await realtime.simple_response("Hello, can you hear me?") @@ -46,38 +46,36 @@ async def on_audio(event: RealtimeAudioOutputEvent): async def test_audio_sending_flow(self, realtime, mia_audio_16khz): """Test sending real audio data and verify connection remains stable""" events = [] - + @realtime.events.subscribe async def on_audio(event: RealtimeAudioOutputEvent): events.append(event) - + await asyncio.sleep(0.01) await realtime.connect() - + # Wait for connection to be fully established await asyncio.sleep(2.0) - + # Convert 16kHz audio to 48kHz for OpenAI realtime # OpenAI expects 48kHz PCM audio import numpy as np from scipy import signal from vision_agents.core.edge.types import PcmData - + # Resample from 16kHz to 48kHz samples_16k = mia_audio_16khz.samples num_samples_48k = int(len(samples_16k) * 48000 / 16000) samples_48k = signal.resample(samples_16k, num_samples_48k).astype(np.int16) - + # Create new PcmData with 48kHz - audio_48khz = PcmData( - samples=samples_48k, - sample_rate=48000, - format="s16" + audio_48khz = PcmData(samples=samples_48k, sample_rate=48000, format="s16") + + await realtime.simple_response( + "Listen to the following audio and tell me what you hear" ) - - await realtime.simple_response("Listen to the following audio and tell me what you hear") await asyncio.sleep(5.0) - + # Send the resampled audio await realtime.simple_audio_response(audio_48khz) @@ -89,22 +87,21 @@ async def on_audio(event: RealtimeAudioOutputEvent): async def test_video_sending_flow(self, realtime, bunny_video_track): """Test sending real video data and verify connection remains stable""" events = [] - + @realtime.events.subscribe async def on_audio(event: RealtimeAudioOutputEvent): events.append(event) - + await asyncio.sleep(0.01) await realtime.connect() await realtime.simple_response("Describe what you see in this video please") await asyncio.sleep(10.0) # Start video sender with low FPS to avoid overwhelming the connection await realtime._watch_video_track(bunny_video_track) - + # Let it run for a few seconds await asyncio.sleep(10.0) - + # Stop video sender await realtime._stop_watching_video_track() assert len(events) > 0 - diff --git a/plugins/openai/vision_agents/plugins/openai/__init__.py b/plugins/openai/vision_agents/plugins/openai/__init__.py index be4ca2e4..1535cfd5 100644 --- a/plugins/openai/vision_agents/plugins/openai/__init__.py +++ b/plugins/openai/vision_agents/plugins/openai/__init__.py @@ 
-1,6 +1,4 @@ - from .openai_llm import OpenAILLM as LLM from .openai_realtime import Realtime __all__ = ["Realtime", "LLM"] - diff --git a/plugins/openai/vision_agents/plugins/openai/events.py b/plugins/openai/vision_agents/plugins/openai/events.py index 56eb06a8..b82ac2ff 100644 --- a/plugins/openai/vision_agents/plugins/openai/events.py +++ b/plugins/openai/vision_agents/plugins/openai/events.py @@ -6,7 +6,8 @@ @dataclass class OpenAIStreamEvent(PluginBaseEvent): """Event emitted when OpenAI provides a stream event.""" - type: str = field(default='plugin.openai.stream', init=False) + + type: str = field(default="plugin.openai.stream", init=False) event_type: Optional[str] = None event_data: Optional[Any] = None @@ -14,6 +15,7 @@ class OpenAIStreamEvent(PluginBaseEvent): @dataclass class LLMErrorEvent(PluginBaseEvent): """Event emitted when an LLM encounters an error.""" - type: str = field(default='plugin.llm.error', init=False) + + type: str = field(default="plugin.llm.error", init=False) error_message: Optional[str] = None event_data: Optional[Any] = None diff --git a/plugins/openai/vision_agents/plugins/openai/openai_llm.py b/plugins/openai/vision_agents/plugins/openai/openai_llm.py index c16a850c..d3c22a33 100644 --- a/plugins/openai/vision_agents/plugins/openai/openai_llm.py +++ b/plugins/openai/vision_agents/plugins/openai/openai_llm.py @@ -3,13 +3,20 @@ from openai import AsyncOpenAI from openai.lib.streaming.responses import ResponseStreamEvent -from openai.types.responses import ResponseCompletedEvent, ResponseTextDeltaEvent, Response as OpenAIResponse +from openai.types.responses import ( + ResponseCompletedEvent, + ResponseTextDeltaEvent, + Response as OpenAIResponse, +) from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import Participant from vision_agents.core.llm.llm import LLM, LLMResponseEvent from vision_agents.core.llm.llm_types import ToolSchema, NormalizedToolCallItem -from vision_agents.core.llm.events import LLMResponseChunkEvent, LLMResponseCompletedEvent +from vision_agents.core.llm.events import ( + LLMResponseChunkEvent, + LLMResponseCompletedEvent, +) from . import events from vision_agents.core.processors import Processor @@ -42,7 +49,12 @@ class OpenAILLM(LLM): """ - def __init__(self, model: str, api_key: Optional[str] = None, client: Optional[AsyncOpenAI] = None): + def __init__( + self, + model: str, + api_key: Optional[str] = None, + client: Optional[AsyncOpenAI] = None, + ): """ Initialize the OpenAILLM class. @@ -64,8 +76,12 @@ def __init__(self, model: str, api_key: Optional[str] = None, client: Optional[A else: self.client = AsyncOpenAI() - async def simple_response(self, text: str, processors: Optional[List[Processor]] = None, - participant: Participant = None): + async def simple_response( + self, + text: str, + processors: Optional[List[Processor]] = None, + participant: Participant = None, + ): """ simple_response is a standardized way (across openai, claude, gemini etc.) to create a response. 
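Several hunks below reshape `_convert_tools_to_provider_format` for the OpenAI Responses API, where `name`, `description`, and `parameters` sit at the top level rather than under a nested `"function"` key. A standalone sketch of that mapping (ToolSchema is treated as a plain dict here and the helper name `to_responses_tools` is hypothetical; the field layout follows the diff):

```python
from typing import Any, Dict, List


def to_responses_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Mirror of the ToolSchema -> Responses API tool mapping used below."""
    out: List[Dict[str, Any]] = []
    for tool in tools:
        params = dict(tool.get("parameters_schema") or {"type": "object"})
        # strict mode expects a closed object schema.
        params.setdefault("properties", {})
        params.setdefault("additionalProperties", False)
        out.append(
            {
                "type": "function",
                "name": tool["name"],  # top-level, not nested under "function"
                "description": tool.get("description", ""),
                "parameters": params,
                "strict": True,
            }
        )
    return out


print(to_responses_tools([{"name": "get_weather", "parameters_schema": {"type": "object"}}]))
```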
@@ -80,7 +96,7 @@ async def simple_response(self, text: str, processors: Optional[List[Processor]] """ # Use enhanced instructions if available (includes markdown file contents) instructions = None - if hasattr(self, 'parsed_instructions') and self.parsed_instructions: + if hasattr(self, "parsed_instructions") and self.parsed_instructions: instructions = self._build_enhanced_instructions() elif self.conversation is not None: instructions = self.conversation.instructions @@ -90,7 +106,9 @@ async def simple_response(self, text: str, processors: Optional[List[Processor]] instructions=instructions, ) - async def create_response(self, *args: Any, **kwargs: Any) -> LLMResponseEvent[OpenAIResponse]: + async def create_response( + self, *args: Any, **kwargs: Any + ) -> LLMResponseEvent[OpenAIResponse]: """ create_response gives you full support/access to the native openAI responses.create method this method wraps the openAI method and ensures we broadcast an event which the agent class hooks into @@ -110,8 +128,7 @@ async def create_response(self, *args: Any, **kwargs: Any) -> LLMResponseEvent[O kwargs["tools"] = self._convert_tools_to_provider_format(tools_spec) # type: ignore[arg-type] # Use parsed instructions if available (includes markdown file contents) - if hasattr(self, 'parsed_instructions') and self.parsed_instructions: - + if hasattr(self, "parsed_instructions") and self.parsed_instructions: # Combine original instructions with markdown file contents enhanced_instructions = self._build_enhanced_instructions() if enhanced_instructions: @@ -122,43 +139,49 @@ async def create_response(self, *args: Any, **kwargs: Any) -> LLMResponseEvent[O # Use the first positional argument as input, or create a default input_content = args[0] if args else "Hello" kwargs["input"] = input_content - + # OpenAI Responses API only accepts keyword arguments response = await self.client.responses.create(**kwargs) - llm_response : Optional[LLMResponseEvent[OpenAIResponse]] = None + llm_response: Optional[LLMResponseEvent[OpenAIResponse]] = None if isinstance(response, OpenAIResponse): # Non-streaming response - llm_response = LLMResponseEvent[OpenAIResponse](response, response.output_text) - + llm_response = LLMResponseEvent[OpenAIResponse]( + response, response.output_text + ) + # Check for tool calls in non-streaming response tool_calls = self._extract_tool_calls_from_response(response) if tool_calls: # Execute tools and get follow-up response llm_response = await self._handle_tool_calls(tool_calls, kwargs) - + elif hasattr(response, "__aiter__"): # async stream # Streaming response stream_response = response pending_tool_calls = [] seen = set() - + # Process streaming events and collect tool calls async for event in stream_response: llm_response_optional = self._standardize_and_emit_event(event) if llm_response_optional is not None: llm_response = llm_response_optional - + # Grab tool calls when the model finalizes the turn if getattr(event, "type", "") == "response.completed": calls = self._extract_tool_calls_from_response(event.response) for c in calls: - key = (c["id"], c["name"], json.dumps(c["arguments_json"], sort_keys=True)) + key = ( + c["id"], + c["name"], + json.dumps(c["arguments_json"], sort_keys=True), + ) if key not in seen: pending_tool_calls.append(c) seen.add(key) - + # If we have tool calls, execute them and get follow-up response if pending_tool_calls: llm_response = await self._handle_tool_calls(pending_tool_calls, kwargs) @@ -171,22 +194,25 @@ async def create_response(self, *args: Any, 
**kwargs: Any) -> LLMResponseEvent[O # Only emit it here for non-streaming responses to avoid duplication. if llm_response is not None and isinstance(response, OpenAIResponse): # Non-streaming response - emit completion event - self.events.send(LLMResponseCompletedEvent( - original=llm_response.original, - text=llm_response.text - )) + self.events.send( + LLMResponseCompletedEvent( + original=llm_response.original, text=llm_response.text + ) + ) return llm_response or LLMResponseEvent[OpenAIResponse](None, "") # type: ignore[arg-type] - async def _handle_tool_calls(self, tool_calls: List[NormalizedToolCallItem], original_kwargs: Dict[str, Any]) -> LLMResponseEvent[OpenAIResponse]: + async def _handle_tool_calls( + self, tool_calls: List[NormalizedToolCallItem], original_kwargs: Dict[str, Any] + ) -> LLMResponseEvent[OpenAIResponse]: """ Handle tool calls by executing them and getting a follow-up response. Supports multi-round tool calling (max 3 rounds). - + Args: tool_calls: List of tool calls to execute original_kwargs: Original kwargs from the request - + Returns: LLM response with tool results """ @@ -195,15 +221,17 @@ async def _handle_tool_calls(self, tool_calls: List[NormalizedToolCallItem], ori current_tool_calls = tool_calls current_kwargs = original_kwargs.copy() seen: set[tuple] = set() - + for round_num in range(max_rounds): # Execute tools (with cross-round deduplication) - triples, seen = await self._dedup_and_execute(current_tool_calls, max_concurrency=8, timeout_s=30, seen=seen) # type: ignore[arg-type] - + triples, seen = await self._dedup_and_execute( + current_tool_calls, max_concurrency=8, timeout_s=30, seen=seen + ) # type: ignore[arg-type] + # If no tools were executed, break the loop if not triples: break - + # Process all tool calls, including failed ones tool_messages = [] for tc, res, err in triples: @@ -211,74 +239,86 @@ async def _handle_tool_calls(self, tool_calls: List[NormalizedToolCallItem], ori if not cid: # Skip tool calls without ID - they can't be reported back continue - + # Use error result if there was an error, otherwise use the result output = err if err is not None else res - + # Convert to string for OpenAI Responses API with sanitization output_str = self._sanitize_tool_output(output) - tool_messages.append({ - "type": "function_call_output", - "call_id": cid, - "output": output_str, - }) - + tool_messages.append( + { + "type": "function_call_output", + "call_id": cid, + "output": output_str, + } + ) + # Don't send empty tool result inputs if not tool_messages: return llm_response or LLMResponseEvent[OpenAIResponse](None, "") # type: ignore[arg-type] - + # Send follow-up request with tool results if not self.openai_conversation: return llm_response or LLMResponseEvent[OpenAIResponse](None, "") # type: ignore[arg-type] - + follow_up_kwargs = { "model": current_kwargs.get("model", self.model), "conversation": self.openai_conversation.id, "input": tool_messages, "stream": True, } - + # Include tools again for potential follow-up calls tools_spec = self._get_tools_for_provider() if tools_spec: - follow_up_kwargs["tools"] = self._convert_tools_to_provider_format(tools_spec) # type: ignore[arg-type] - + follow_up_kwargs["tools"] = self._convert_tools_to_provider_format( + tools_spec + ) # type: ignore[arg-type] + # Get follow-up response follow_up_response = await self.client.responses.create(**follow_up_kwargs) - + if isinstance(follow_up_response, OpenAIResponse): # Non-streaming response - llm_response = 
LLMResponseEvent[OpenAIResponse](follow_up_response, follow_up_response.output_text) - + llm_response = LLMResponseEvent[OpenAIResponse]( + follow_up_response, follow_up_response.output_text + ) + # Check for more tool calls - next_tool_calls = self._extract_tool_calls_from_response(follow_up_response) + next_tool_calls = self._extract_tool_calls_from_response( + follow_up_response + ) if next_tool_calls and round_num < max_rounds - 1: current_tool_calls = next_tool_calls current_kwargs = follow_up_kwargs continue else: return llm_response - + elif hasattr(follow_up_response, "__aiter__"): # async stream stream_response = follow_up_response llm_response = None pending_tool_calls = [] # Don't reset seen - keep deduplication across rounds - + async for event in stream_response: llm_response_optional = self._standardize_and_emit_event(event) if llm_response_optional is not None: llm_response = llm_response_optional - + # Check for more tool calls if getattr(event, "type", "") == "response.completed": calls = self._extract_tool_calls_from_response(event.response) for c in calls: - key = (c["id"], c["name"], json.dumps(c["arguments_json"], sort_keys=True)) + key = ( + c["id"], + c["name"], + json.dumps(c["arguments_json"], sort_keys=True), + ) if key not in seen: pending_tool_calls.append(c) seen.add(key) - + # If we have more tool calls and haven't exceeded max rounds, continue if pending_tool_calls and round_num < max_rounds - 1: current_tool_calls = pending_tool_calls @@ -289,7 +329,7 @@ async def _handle_tool_calls(self, tool_calls: List[NormalizedToolCallItem], ori else: # Defensive fallback return LLMResponseEvent[OpenAIResponse](None, "") # type: ignore[arg-type] - + # If we've exhausted all rounds, return the last response return llm_response or LLMResponseEvent[OpenAIResponse](None, "") # type: ignore[arg-type] @@ -302,9 +342,7 @@ def _normalize_message(openai_input) -> List["Message"]: # standardize on input if isinstance(openai_input, str): - openai_input = [ - dict(content=openai_input, role="user", type="message") - ] + openai_input = [dict(content=openai_input, role="user", type="message")] elif not isinstance(openai_input, List): openai_input = [openai_input] @@ -316,15 +354,15 @@ def _normalize_message(openai_input) -> List["Message"]: return messages - - - def _convert_tools_to_provider_format(self, tools: List[ToolSchema]) -> List[Dict[str, Any]]: + def _convert_tools_to_provider_format( + self, tools: List[ToolSchema] + ) -> List[Dict[str, Any]]: """ Convert ToolSchema objects to OpenAI Responses API format. - + Args: tools: List of ToolSchema objects from the function registry - + Returns: List of tools in OpenAI Responses API format """ @@ -339,22 +377,26 @@ def _convert_tools_to_provider_format(self, tools: List[ToolSchema]) -> List[Dic params.setdefault("properties", {}) params.setdefault("additionalProperties", False) - out.append({ - "type": "function", - "name": name, # <-- top-level - "description": description, # <-- top-level - "parameters": params, # <-- top-level - "strict": True, # optional but fine - }) + out.append( + { + "type": "function", + "name": name, # <-- top-level + "description": description, # <-- top-level + "parameters": params, # <-- top-level + "strict": True, # optional but fine + } + ) return out - def _extract_tool_calls_from_response(self, response: Any) -> List[NormalizedToolCallItem]: + def _extract_tool_calls_from_response( + self, response: Any + ) -> List[NormalizedToolCallItem]: """ Extract tool calls from OpenAI response. 
- + Args: response: OpenAI response object - + Returns: List of normalized tool call items """ @@ -368,22 +410,23 @@ def _extract_tool_calls_from_response(self, response: Any) -> List[NormalizedToo args_obj = {} call_item: NormalizedToolCallItem = { "type": "tool_call", - "id": getattr(item, "call_id", ""), # <-- call_id + "id": getattr(item, "call_id", ""), # <-- call_id "name": getattr(item, "name", "unknown"), "arguments_json": args_obj, } calls.append(call_item) return calls - - def _create_tool_result_message(self, tool_calls: List[NormalizedToolCallItem], results: List[Any]) -> List[Dict[str, Any]]: + def _create_tool_result_message( + self, tool_calls: List[NormalizedToolCallItem], results: List[Any] + ) -> List[Dict[str, Any]]: """ Create tool result messages for OpenAI Responses API. - + Args: tool_calls: List of tool calls that were executed results: List of results from function execution - + Returns: List of tool result messages in Responses API format """ @@ -393,56 +436,66 @@ def _create_tool_result_message(self, tool_calls: List[NormalizedToolCallItem], if not call_id: # skip or wrap into a normal assistant message / log an error continue - + # Send only function_call_output items keyed by call_id # Convert to string for Responses API output_str = res if isinstance(res, str) else json.dumps(res) - msgs.append({ - "type": "function_call_output", - "call_id": call_id, - "output": output_str, - }) + msgs.append( + { + "type": "function_call_output", + "call_id": call_id, + "output": output_str, + } + ) return msgs - def _standardize_and_emit_event(self, event: ResponseStreamEvent) -> Optional[LLMResponseEvent]: + def _standardize_and_emit_event( + self, event: ResponseStreamEvent + ) -> Optional[LLMResponseEvent]: """ Forwards the events and also send out a standardized version (the agent class hooks into that) """ # start by forwarding the native event - self.events.send(events.OpenAIStreamEvent( - plugin_name="openai", - event_type=event.type, - event_data=event - )) + self.events.send( + events.OpenAIStreamEvent( + plugin_name="openai", event_type=event.type, event_data=event + ) + ) if event.type == "response.error": # Handle error events error_message = getattr(event, "error", {}).get("message", "Unknown error") - self.events.send(events.LLMErrorEvent( - plugin_name="openai", - error_message=error_message, - event_data=event - )) + self.events.send( + events.LLMErrorEvent( + plugin_name="openai", error_message=error_message, event_data=event + ) + ) return None elif event.type == "response.output_text.delta": # standardize the delta event delta_event: ResponseTextDeltaEvent = event - self.events.send(LLMResponseChunkEvent( - plugin_name="openai", - content_index=delta_event.content_index, - item_id=delta_event.item_id, - output_index=delta_event.output_index, - sequence_number=delta_event.sequence_number, - delta=delta_event.delta, - )) + self.events.send( + LLMResponseChunkEvent( + plugin_name="openai", + content_index=delta_event.content_index, + item_id=delta_event.item_id, + output_index=delta_event.output_index, + sequence_number=delta_event.sequence_number, + delta=delta_event.delta, + ) + ) elif event.type == "response.completed": # standardize the response event and return the llm response completed_event: ResponseCompletedEvent = event - llm_response = LLMResponseEvent[OpenAIResponse](completed_event.response, completed_event.response.output_text) - self.events.send(LLMResponseCompletedEvent( - plugin_name="openai", - original=llm_response.original, - 
text=llm_response.text - )) + llm_response = LLMResponseEvent[OpenAIResponse]( + completed_event.response, completed_event.response.output_text + ) + self.events.send( + LLMResponseCompletedEvent( + plugin_name="openai", + original=llm_response.original, + text=llm_response.text, + ) + ) return llm_response return None diff --git a/plugins/openai/vision_agents/plugins/openai/openai_realtime.py b/plugins/openai/vision_agents/plugins/openai/openai_realtime.py index 2cec37d6..fe982542 100644 --- a/plugins/openai/vision_agents/plugins/openai/openai_realtime.py +++ b/plugins/openai/vision_agents/plugins/openai/openai_realtime.py @@ -2,8 +2,11 @@ from typing import Any, Optional, List, Dict from getstream.video.rtc.audio_track import AudioStreamTrack -from openai.types.realtime import RealtimeSessionCreateRequestParam, ResponseAudioTranscriptDoneEvent, \ - InputAudioBufferSpeechStartedEvent +from openai.types.realtime import ( + RealtimeSessionCreateRequestParam, + ResponseAudioTranscriptDoneEvent, + InputAudioBufferSpeechStartedEvent, +) from vision_agents.core.llm import realtime from vision_agents.core.llm.llm_types import ToolSchema @@ -50,16 +53,17 @@ class Realtime(realtime.Realtime): - MCP integration for external service access. """ - def __init__(self, model: str = "gpt-realtime", voice: str = "marin", *args, **kwargs): + + def __init__( + self, model: str = "gpt-realtime", voice: str = "marin", *args, **kwargs + ): super().__init__(*args, **kwargs) self.model = model self.voice = voice # TODO: send video should depend on if the RTC connection with stream is sending video. self.rtc = RTCManager(self.model, self.voice, True) # audio output track? - self.output_track = AudioStreamTrack( - framerate=48000, stereo=True, format="s16" - ) + self.output_track = AudioStreamTrack(framerate=48000, stereo=True, format="s16") async def connect(self): """Establish the WebRTC connection to OpenAI's Realtime API. @@ -74,23 +78,27 @@ async def connect(self): instructions = self.instructions self.rtc.instructions = instructions - + # Wire callbacks so we can emit audio/events upstream self.rtc.set_event_callback(self._handle_openai_event) self.rtc.set_audio_callback(self._handle_audio_output) await self.rtc.connect() - + # Register tools with OpenAI realtime if available await self._register_tools_with_openai_realtime() - + # Emit connected/ready self._emit_connected_event( session_config={"model": self.model, "voice": self.voice}, capabilities=["text", "audio", "function_calling"], ) - async def simple_response(self, text: str, processors: Optional[List[Processor]] = None, - participant: Optional[Participant] = None): + async def simple_response( + self, + text: str, + processors: Optional[List[Processor]] = None, + participant: Optional[Participant] = None, + ): """Send a simple text input to the OpenAI Realtime session. 
This is a convenience wrapper that forwards a text prompt upstream via @@ -150,9 +158,19 @@ async def _handle_openai_event(self, event: dict) -> None: """ et = event.get("type") if et == "response.audio_transcript.done": - transcript_event: ResponseAudioTranscriptDoneEvent = ResponseAudioTranscriptDoneEvent.model_validate(event) - self._emit_transcript_event(text=transcript_event.transcript, user_metadata={"role": "assistant", "source": "openai"}) - self._emit_response_event(text=transcript_event.transcript, response_id=transcript_event.response_id, is_complete=True, conversation_item_id=transcript_event.item_id) + transcript_event: ResponseAudioTranscriptDoneEvent = ( + ResponseAudioTranscriptDoneEvent.model_validate(event) + ) + self._emit_transcript_event( + text=transcript_event.transcript, + user_metadata={"role": "assistant", "source": "openai"}, + ) + self._emit_response_event( + text=transcript_event.transcript, + response_id=transcript_event.response_id, + is_complete=True, + conversation_item_id=transcript_event.item_id, + ) elif et == "input_audio_buffer.speech_started": # Validate event but don't need to store it InputAudioBufferSpeechStartedEvent.model_validate(event) @@ -182,20 +200,22 @@ async def _handle_audio_output(self, audio_bytes: bytes) -> None: audio_data=audio_bytes, sample_rate=48000, # OpenAI Realtime uses 48kHz ) - + # Forward audio to output track for playback await self.output_track.write(audio_bytes) async def _watch_video_track(self, track, **kwargs) -> None: - shared_forwarder = kwargs.get('shared_forwarder') - await self.rtc.start_video_sender(track, self.fps, shared_forwarder=shared_forwarder) + shared_forwarder = kwargs.get("shared_forwarder") + await self.rtc.start_video_sender( + track, self.fps, shared_forwarder=shared_forwarder + ) async def _stop_watching_video_track(self) -> None: await self.rtc.stop_video_sender() async def _handle_tool_call_event(self, event: dict) -> None: """Handle tool call events from OpenAI realtime. 
- + Args: event: Tool call event from OpenAI realtime API """ @@ -206,24 +226,26 @@ async def _handle_tool_call_event(self, event: dict) -> None: tool_call_data = item else: tool_call_data = event.get("tool_call", {}) - + if not tool_call_data: logger.warning("Received tool call event without tool_call data") return - + # Extract tool call details tool_call = { "type": "tool_call", "id": tool_call_data.get("call_id"), "name": tool_call_data.get("name", "unknown"), - "arguments_json": tool_call_data.get("arguments", {}) + "arguments_json": tool_call_data.get("arguments", {}), } - - logger.info(f"Executing tool call: {tool_call['name']} with args: {tool_call['arguments_json']}") - + + logger.info( + f"Executing tool call: {tool_call['name']} with args: {tool_call['arguments_json']}" + ) + # Execute using existing tool execution infrastructure tc, result, error = await self._run_one_tool(tool_call, timeout_s=30) - + # Prepare response data if error: response_data = {"error": str(error)} @@ -235,10 +257,10 @@ async def _handle_tool_call_event(self, event: dict) -> None: else: response_data = result logger.info(f"Tool call {tool_call['name']} succeeded: {response_data}") - + # Send tool response back to OpenAI realtime session await self._send_tool_response(tool_call["id"], response_data) - + except Exception as e: logger.error(f"Error handling tool call event: {e}") # Send error response back @@ -249,9 +271,11 @@ async def _handle_tool_call_event(self, event: dict) -> None: call_id = event.get("tool_call", {}).get("call_id") await self._send_tool_response(call_id, {"error": str(e)}) - async def _send_tool_response(self, call_id: Optional[str], response_data: Dict[str, Any]) -> None: + async def _send_tool_response( + self, call_id: Optional[str], response_data: Dict[str, Any] + ) -> None: """Send tool response back to OpenAI realtime session. - + Args: call_id: The call ID from the original tool call response_data: The response data to send back @@ -259,44 +283,46 @@ async def _send_tool_response(self, call_id: Optional[str], response_data: Dict[ if not call_id: logger.warning("Cannot send tool response without call_id") return - + try: # Convert response to string for OpenAI realtime response_str = self._sanitize_tool_output(response_data) - + # Send tool response event event = { "type": "conversation.item.create", "item": { "type": "function_call_output", "call_id": call_id, - "output": response_str - } + "output": response_str, + }, } - + await self.rtc._send_event(event) logger.info(f"Sent tool response for call_id {call_id}") - + # Trigger a new response to continue the conversation with audio # This ensures the AI responds with audio after receiving the tool result - await self.rtc._send_event({ - "type": "response.create", - "response": { - "modalities": ["text", "audio"], - "instructions": "Please respond to the user with the tool results in a conversational way." + await self.rtc._send_event( + { + "type": "response.create", + "response": { + "modalities": ["text", "audio"], + "instructions": "Please respond to the user with the tool results in a conversational way.", + }, } - }) - + ) + except Exception as e: logger.error(f"Failed to send tool response: {e}") def _sanitize_tool_output(self, value: Any, max_chars: int = 60_000) -> str: """Sanitize tool output for OpenAI realtime. 
- + Args: value: The tool output to sanitize max_chars: Maximum characters allowed (not used in realtime mode) - + Returns: Sanitized string output """ @@ -307,12 +333,14 @@ def _sanitize_tool_output(self, value: Any, max_chars: int = 60_000) -> str: else: return str(value) - def _convert_tools_to_openai_realtime_format(self, tools: List[ToolSchema]) -> List[Dict[str, Any]]: + def _convert_tools_to_openai_realtime_format( + self, tools: List[ToolSchema] + ) -> List[Dict[str, Any]]: """Convert ToolSchema objects to OpenAI realtime format. - + Args: tools: List of ToolSchema objects from the function registry - + Returns: List of tools in OpenAI realtime format """ @@ -327,46 +355,50 @@ def _convert_tools_to_openai_realtime_format(self, tools: List[ToolSchema]) -> L params.setdefault("properties", {}) params.setdefault("additionalProperties", False) - out.append({ - "type": "function", - "name": name, - "description": description, - "parameters": params, - }) + out.append( + { + "type": "function", + "name": name, + "description": description, + "parameters": params, + } + ) return out async def _register_tools_with_openai_realtime(self) -> None: """Register available tools with OpenAI realtime session. - + This method registers all available functions and MCP tools with the OpenAI realtime session so they can be called during conversations. """ try: # Get available tools from function registry available_tools = self.get_available_functions() - + if not available_tools: logger.info("No tools available to register with OpenAI realtime") return - + # Convert tools to OpenAI realtime format - tools_for_openai = self._convert_tools_to_openai_realtime_format(available_tools) - + tools_for_openai = self._convert_tools_to_openai_realtime_format( + available_tools + ) + if not tools_for_openai: logger.info("No tools converted for OpenAI realtime") return - + # Send tools configuration to OpenAI realtime tools_event = { "type": "session.update", - "session": { - "tools": tools_for_openai - } + "session": {"tools": tools_for_openai}, } - + await self.rtc._send_event(tools_event) - logger.info(f"Registered {len(tools_for_openai)} tools with OpenAI realtime") - + logger.info( + f"Registered {len(tools_for_openai)} tools with OpenAI realtime" + ) + except Exception as e: logger.error(f"Failed to register tools with OpenAI realtime: {e}") # Don't raise the exception - tool registration failure shouldn't break the connection diff --git a/plugins/openai/vision_agents/plugins/openai/rtc_manager.py b/plugins/openai/vision_agents/plugins/openai/rtc_manager.py index 7f82a32b..56faa45b 100644 --- a/plugins/openai/vision_agents/plugins/openai/rtc_manager.py +++ b/plugins/openai/vision_agents/plugins/openai/rtc_manager.py @@ -42,7 +42,7 @@ def __init__(self, sample_rate: int = 48000): self._latest_chunk: Optional[bytes] = None self._silence_cache: dict[int, np.ndarray] = {} - def set_input (self, pcm_data: bytes, sample_rate: Optional[int] = None) -> None: + def set_input(self, pcm_data: bytes, sample_rate: Optional[int] = None) -> None: if not pcm_data: return if sample_rate is not None: @@ -93,16 +93,18 @@ class StreamVideoForwardingTrack(VideoStreamTrack): """Track that forwards frames from Stream Video to OpenAI. 
TODO: why do we have this forwarding track, when there is the video_forwarder """ - + kind = "video" - - def __init__(self, source_track: MediaStreamTrack, fps: int = 1, shared_forwarder=None): + + def __init__( + self, source_track: MediaStreamTrack, fps: int = 1, shared_forwarder=None + ): super().__init__() self._source_track = source_track self._fps = max(1, fps) self._interval = 1.0 / self._fps self._ts = 0 - self._last_frame_time = 0. + self._last_frame_time = 0.0 self._frame_count = 0 self._error_count = 0 self._consecutive_errors = 0 @@ -116,27 +118,33 @@ def __init__(self, source_track: MediaStreamTrack, fps: int = 1, shared_forwarde self._started: bool = False # Rate limiting for inactive track warnings - self._last_inactive_warning = 0. + self._last_inactive_warning = 0.0 self._inactive_warning_interval = 30.0 # Only warn every 30 seconds - + if shared_forwarder: - logger.info(f"🎥 StreamVideoForwardingTrack initialized with SHARED forwarder: fps={fps}, interval={self._interval:.3f}s") + logger.info( + f"🎥 StreamVideoForwardingTrack initialized with SHARED forwarder: fps={fps}, interval={self._interval:.3f}s" + ) else: - logger.info(f"🎥 StreamVideoForwardingTrack initialized: fps={fps}, interval={self._interval:.3f}s (frame limiting DISABLED for performance)") - + logger.info( + f"🎥 StreamVideoForwardingTrack initialized: fps={fps}, interval={self._interval:.3f}s (frame limiting DISABLED for performance)" + ) + async def start(self) -> None: if self._started: return - + if self._shared_forwarder is not None: # Use the shared forwarder self._forwarder = self._shared_forwarder logger.info(f"🎥 OpenAI using shared VideoForwarder at {self._fps} FPS") else: # Create our own VideoForwarder with the input source track (legacy behavior) - self._forwarder = VideoForwarder(self._source_track, max_buffer=5, fps=self._fps) # type: ignore[arg-type] + self._forwarder = VideoForwarder( + self._source_track, max_buffer=5, fps=self._fps + ) # type: ignore[arg-type] await self._forwarder.start() - + self._started = True async def recv(self): @@ -148,7 +156,9 @@ async def recv(self): # Rate limit warnings to avoid spam now = time.monotonic() if now - self._last_inactive_warning > self._inactive_warning_interval: - logger.warning("🎥 StreamVideoForwardingTrack is no longer active, returning black frame") + logger.warning( + "🎥 StreamVideoForwardingTrack is no longer active, returning black frame" + ) self._last_inactive_warning = now return self._generate_black_frame() @@ -157,8 +167,12 @@ async def recv(self): # Health check: detect if track has been dead for too long if now - self._last_health_check > self._health_check_interval: self._last_health_check = now - if now - self._last_successful_frame_time > 30.0: # No frames for 30 seconds - logger.error("🎥 StreamVideoForwardingTrack health check failed - no frames for 30+ seconds") + if ( + now - self._last_successful_frame_time > 30.0 + ): # No frames for 30 seconds + logger.error( + "🎥 StreamVideoForwardingTrack health check failed - no frames for 30+ seconds" + ) self._is_active = False return self._generate_black_frame() @@ -178,7 +192,9 @@ async def recv(self): try: frame = frame.reformat(format="rgb24") except Exception as e: - logger.warning(f"🎥 Frame format conversion failed: {e}, using original") + logger.warning( + f"🎥 Frame format conversion failed: {e}, using original" + ) # Update timing for WebRTC frame.pts = self._ts @@ -191,18 +207,24 @@ async def recv(self): except asyncio.TimeoutError: self._consecutive_errors += 1 if 
self._consecutive_errors >= self._max_consecutive_errors: - logger.error(f"🎥 StreamVideoForwardingTrack circuit breaker triggered - {self._consecutive_errors} consecutive timeouts") + logger.error( + f"🎥 StreamVideoForwardingTrack circuit breaker triggered - {self._consecutive_errors} consecutive timeouts" + ) self._is_active = False return self._generate_black_frame() except Exception as e: self._consecutive_errors += 1 self._error_count += 1 - logger.error(f"❌ FRAME ERROR: frame_id={self._frame_count} (error #{self._error_count}, consecutive={self._consecutive_errors}): {e}") + logger.error( + f"❌ FRAME ERROR: frame_id={self._frame_count} (error #{self._error_count}, consecutive={self._consecutive_errors}): {e}" + ) if self._consecutive_errors >= self._max_consecutive_errors: - logger.error(f"🎥 StreamVideoForwardingTrack circuit breaker triggered - {self._consecutive_errors} consecutive errors") + logger.error( + f"🎥 StreamVideoForwardingTrack circuit breaker triggered - {self._consecutive_errors} consecutive errors" + ) self._is_active = False return self._generate_black_frame() - + def _generate_black_frame(self) -> VideoFrame: """Generate a black frame as fallback.""" black_array = np.zeros((480, 640, 3), dtype=np.uint8) @@ -211,10 +233,12 @@ def _generate_black_frame(self) -> VideoFrame: frame.time_base = Fraction(1, self._fps) self._ts += 1 return frame - + def stop(self): """Stop the forwarding track and forwarder.""" - logger.info(f"🎥 StreamVideoForwardingTrack stopped after {self._frame_count} frames, {self._error_count} errors") + logger.info( + f"🎥 StreamVideoForwardingTrack stopped after {self._frame_count} frames, {self._error_count} errors" + ) try: if self._forwarder is not None: asyncio.create_task(self._forwarder.stop()) @@ -281,7 +305,6 @@ async def on_track(track): await self.pc.setRemoteDescription(answer) logger.info("Remote description set; WebRTC established") - async def _get_session_token(self) -> str: url = OPENAI_SESSIONS_URL headers = { @@ -298,7 +321,9 @@ async def _get_session_token(self) -> str: async with AsyncClient() as client: for attempt in range(2): try: - resp = await client.post(url, headers=headers, json=payload, timeout=15) + resp = await client.post( + url, headers=headers, json=payload, timeout=15 + ) resp.raise_for_status() data: dict = resp.json() secret = data.get("client_secret", {}) @@ -321,15 +346,11 @@ async def on_open(): # Immediately switch to semantic VAD so it's active before the user speaks await self._send_event( - { - "type": "session.update", - "session": { - "turn_detection": { - "type": "semantic_vad" - } - }, - } - ) + { + "type": "session.update", + "session": {"turn_detection": {"type": "semantic_vad"}}, + } + ) # Session information will be automatically stored when session.created event is received logger.info("Requested semantic_vad via session.update") @@ -369,7 +390,6 @@ async def recv(self): # Keep a handle to the currently active source (if any) for diagnostics / control self._active_video_source: Optional[MediaStreamTrack] = None - async def send_audio_pcm(self, pcm_data: PcmData) -> None: """Send raw PCM audio data to OpenAI. @@ -394,10 +414,9 @@ async def send_audio_pcm(self, pcm_data: PcmData) -> None: except Exception as e: logger.error(f"Failed to push mic audio: {e}") - async def send_text(self, text: str, role: str = "user"): """Send a text message to OpenAI. - + Args: text: The text message to send. role: Message role. Defaults to "user". 
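(Aside: the `_send_event` hunk below only re-wraps logging calls, but the open-gate pattern it formats is worth spelling out. A minimal sketch, assuming an `asyncio.Event` that the data channel's open callback sets; `send_when_open`, `open_event`, and the 5 s timeout are illustrative names, not the plugin's API:)

import asyncio
import json
import logging

logger = logging.getLogger(__name__)

async def send_when_open(data_channel, open_event: asyncio.Event, event: dict) -> None:
    # Block (bounded) until the datachannel "open" callback sets the flag.
    if not open_event.is_set():
        try:
            await asyncio.wait_for(open_event.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            logger.warning("Data channel not open after timeout; dropping event")
            return
    # Note: the hunk below logs "cannot send event" on a bad readyState but
    # still falls through to send(); returning here is the stricter variant.
    if data_channel.readyState != "open":
        logger.warning(
            "Data channel state is %r, cannot send event", data_channel.readyState
        )
        return
    data_channel.send(json.dumps(event))

(`data_channel` is any object with aiortc's `readyState` attribute and `send(str)` method.)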
@@ -428,13 +447,19 @@ async def _send_event(self, event: dict): # Ensure the data channel is open before sending if not self._data_channel_open_event.is_set(): try: - await asyncio.wait_for(self._data_channel_open_event.wait(), timeout=5.0) + await asyncio.wait_for( + self._data_channel_open_event.wait(), timeout=5.0 + ) except asyncio.TimeoutError: - logger.warning("Data channel not open after timeout; dropping event") + logger.warning( + "Data channel not open after timeout; dropping event" + ) return if self.data_channel.readyState and self.data_channel.readyState != "open": - logger.warning(f"Data channel state is '{self.data_channel.readyState}', cannot send event") + logger.warning( + f"Data channel state is '{self.data_channel.readyState}', cannot send event" + ) message_json = json.dumps(event) self.data_channel.send(message_json) @@ -442,7 +467,9 @@ async def _send_event(self, event: dict): except Exception as e: logger.error(f"Failed to send event: {e}") - async def start_video_sender(self, stream_video_track: MediaStreamTrack, fps: int = 1, shared_forwarder=None) -> None: + async def start_video_sender( + self, stream_video_track: MediaStreamTrack, fps: int = 1, shared_forwarder=None + ) -> None: """Replace dummy video track with the actual Stream Video forwarding track. This creates a forwarding track that reads frames from the Stream Video track @@ -453,15 +480,19 @@ async def start_video_sender(self, stream_video_track: MediaStreamTrack, fps: in fps: Target frames per second. shared_forwarder: Optional shared VideoForwarder to use instead of creating a new one. """ - + try: if not self.send_video: logger.error("❌ Video sending not enabled for this session") raise RuntimeError("Video sending not enabled for this session") if self._video_sender is None: - logger.error("❌ Video sender not available; was video track negotiated?") - raise RuntimeError("Video sender not available; was video track negotiated?") - + logger.error( + "❌ Video sender not available; was video track negotiated?" + ) + raise RuntimeError( + "Video sender not available; was video track negotiated?" 
+ ) + # Stop any existing video sender task if self._video_sender_task is not None: logger.info("🎥 Stopping existing video sender task...") @@ -471,23 +502,29 @@ async def start_video_sender(self, stream_video_track: MediaStreamTrack, fps: in except asyncio.CancelledError: pass logger.info("🎥 Existing video sender task stopped") - + # Create forwarding track and start its forwarder - forwarding_track = StreamVideoForwardingTrack(stream_video_track, fps, shared_forwarder=shared_forwarder) + forwarding_track = StreamVideoForwardingTrack( + stream_video_track, fps, shared_forwarder=shared_forwarder + ) await forwarding_track.start() - + # Replace the dummy track with the forwarding track try: - logger.info("🎥 Replacing OpenAI dummy track with StreamVideoForwardingTrack") + logger.info( + "🎥 Replacing OpenAI dummy track with StreamVideoForwardingTrack" + ) self._video_sender.replaceTrack(forwarding_track) self._forwarding_track = forwarding_track self._active_video_source = stream_video_track - logger.info(f"✅ Successfully replaced OpenAI track with Stream Video forwarding (fps={fps})") + logger.info( + f"✅ Successfully replaced OpenAI track with Stream Video forwarding (fps={fps})" + ) except Exception as replace_error: logger.error(f"❌ Failed to replace video track: {replace_error}") logger.error(f"❌ Replace error type: {type(replace_error).__name__}") raise RuntimeError(f"Track replacement failed: {replace_error}") - + except Exception as e: logger.error(f"❌ Failed to start video sender: {e}") logger.error(f"❌ Error type: {type(e).__name__}") @@ -511,11 +548,11 @@ async def stop_video_sender(self) -> None: except Exception: pass self._forwarding_track = None - + if self._video_sender is None: logger.warning("No video sender available to stop") return - + # Replace track with proper error handling try: if self._video_track is None: @@ -524,13 +561,15 @@ async def stop_video_sender(self) -> None: logger.info("✅ Video sender detached (no base track)") else: self._video_sender.replaceTrack(self._video_track) - logger.info(f"✅ Video sender reverted to dummy track: {type(self._video_track).__name__}") - + logger.info( + f"✅ Video sender reverted to dummy track: {type(self._video_track).__name__}" + ) + self._active_video_source = None except Exception as replace_error: logger.error(f"❌ Failed to revert video track: {replace_error}") raise RuntimeError(f"Track reversion failed: {replace_error}") - + except Exception as e: logger.error(f"❌ Failed to stop video sender: {e}") raise @@ -557,7 +596,9 @@ async def _exchange_sdp(self, local_sdp: str) -> Optional[str]: try: async with AsyncClient() as client: - response = await client.post(url, headers=headers, content=local_sdp, timeout=20) + response = await client.post( + url, headers=headers, content=local_sdp, timeout=20 + ) response.raise_for_status() return response.text if response.text else None except HTTPStatusError as e: @@ -576,7 +617,10 @@ async def _handle_added_track(self, track: MediaStreamTrack) -> None: async def _reader(): while True: try: - frame: AudioFrame = cast(AudioFrame, await asyncio.wait_for(track.recv(), timeout=1.0)) + frame: AudioFrame = cast( + AudioFrame, + await asyncio.wait_for(track.recv(), timeout=1.0), + ) except asyncio.TimeoutError: continue except Exception as e: @@ -597,8 +641,8 @@ async def _reader(): await cb(audio_bytes) except Exception as e: logger.debug(f"Failed to process remote audio frame: {e}") + asyncio.create_task(_reader()) - async def _handle_event(self, event: dict) -> None: """Minimal event 
handler for data channel messages.""" @@ -610,7 +654,7 @@ async def _handle_event(self, event: dict) -> None: logger.debug(f"Event callback error: {e}") # Store session information when we receive session.created event - # FIXME Typing + # FIXME Typing if event.get("type") == "session.created" and "session" in event: self.session_info = event["session"] logger.debug(f"Stored session info: {self.session_info}") @@ -625,7 +669,9 @@ async def request_session_info(self) -> None: if self.session_info: logger.info(f"Current session info: {self.session_info}") else: - logger.info("No session information available yet. Waiting for session.created event.") + logger.info( + "No session information available yet. Waiting for session.created event." + ) def set_audio_callback(self, callback: Callable[[bytes], Any]) -> None: """Set callback for receiving audio data from OpenAI. @@ -643,51 +689,78 @@ def set_event_callback(self, callback: Callable[[dict], Any]) -> None: """ self._event_callback = callback - async def _forward_video_frames(self, source_track: MediaStreamTrack, fps: int) -> None: + async def _forward_video_frames( + self, source_track: MediaStreamTrack, fps: int + ) -> None: """Forward video frames from user's track to OpenAI via WebRTC. - + This method reads frames from the user's video track and forwards them through the WebRTC connection to OpenAI for processing. """ interval = max(0.01, 1.0 / max(1, fps)) frame_count = 0 - + try: - logger.info(f"🎥 Starting video frame forwarding loop (fps={fps}, interval={interval:.3f}s)") - logger.info(f"🎥 Source track: {type(source_track).__name__}, kind={getattr(source_track, 'kind', 'unknown')}") - + logger.info( + f"🎥 Starting video frame forwarding loop (fps={fps}, interval={interval:.3f}s)" + ) + logger.info( + f"🎥 Source track: {type(source_track).__name__}, kind={getattr(source_track, 'kind', 'unknown')}" + ) + while True: try: # Read frame from user's video track - logger.debug(f"🎥 Attempting to read frame #{frame_count + 1} from user track...") - frame: VideoFrame = cast(VideoFrame, await asyncio.wait_for(source_track.recv(), timeout=1.0)) + logger.debug( + f"🎥 Attempting to read frame #{frame_count + 1} from user track..." + ) + frame: VideoFrame = cast( + VideoFrame, + await asyncio.wait_for(source_track.recv(), timeout=1.0), + ) frame_count += 1 - + # Log frame details - logger.info(f"🎥 SUCCESS: Read frame #{frame_count} from user track!") - logger.info(f"🎥 Frame details: {frame.width}x{frame.height}, format={frame.format}, pts={frame.pts}") - + logger.info( + f"🎥 SUCCESS: Read frame #{frame_count} from user track!" 
+ ) + logger.info( + f"🎥 Frame details: {frame.width}x{frame.height}, format={frame.format}, pts={frame.pts}" + ) + # The frame is automatically forwarded through the WebRTC connection # since we replaced the track with replaceTrack() - logger.debug(f"🎥 Frame #{frame_count} automatically forwarded via WebRTC") - + logger.debug( + f"🎥 Frame #{frame_count} automatically forwarded via WebRTC" + ) + # Throttle frame rate await asyncio.sleep(interval) - + except asyncio.TimeoutError: - logger.warning(f"🎥 Timeout waiting for frame #{frame_count + 1} from user track") + logger.warning( + f"🎥 Timeout waiting for frame #{frame_count + 1} from user track" + ) continue except Exception as e: - logger.error(f"❌ Error reading video frame #{frame_count + 1}: {e}") + logger.error( + f"❌ Error reading video frame #{frame_count + 1}: {e}" + ) logger.error(f"❌ Exception type: {type(e).__name__}") break - + except asyncio.CancelledError: - logger.info(f"🎥 Video forwarding task cancelled after {frame_count} frames") + logger.info( + f"🎥 Video forwarding task cancelled after {frame_count} frames" + ) except Exception as e: - logger.error(f"❌ Video forwarding task failed after {frame_count} frames: {e}") + logger.error( + f"❌ Video forwarding task failed after {frame_count} frames: {e}" + ) finally: - logger.info(f"🎥 Video forwarding task ended. Total frames processed: {frame_count}") + logger.info( + f"🎥 Video forwarding task ended. Total frames processed: {frame_count}" + ) async def close(self) -> None: """Close the WebRTC connection and clean up resources.""" @@ -699,7 +772,7 @@ async def close(self) -> None: await self._video_sender_task except asyncio.CancelledError: pass - + if self.data_channel is not None: try: self.data_channel.close() diff --git a/plugins/silero/tests/test_vad.py b/plugins/silero/tests/test_vad.py index 67d3d167..526b8c91 100644 --- a/plugins/silero/tests/test_vad.py +++ b/plugins/silero/tests/test_vad.py @@ -446,19 +446,19 @@ async def test_silence_no_turns(): async def on_audio(event: VADAudioEvent): nonlocal audio_event_fired audio_event_fired = True - duration_sec = (event.duration_ms / 1000.0) if event.duration_ms is not None else 0.0 - logger.info( - f"Audio event detected on silence! Duration: {duration_sec:.2f}s" + duration_sec = ( + (event.duration_ms / 1000.0) if event.duration_ms is not None else 0.0 ) + logger.info(f"Audio event detected on silence! Duration: {duration_sec:.2f}s") @vad.events.subscribe async def on_partial(event: VADPartialEvent): nonlocal partial_event_fired partial_event_fired = True - duration_sec = (event.duration_ms / 1000.0) if event.duration_ms is not None else 0.0 - logger.info( - f"Partial event detected on silence! Duration: {duration_sec:.2f}s" + duration_sec = ( + (event.duration_ms / 1000.0) if event.duration_ms is not None else 0.0 ) + logger.info(f"Partial event detected on silence! 
Duration: {duration_sec:.2f}s") # Process the silence in chunks to simulate streaming chunk_size = 512 diff --git a/plugins/silero/vision_agents/plugins/silero/vad.py b/plugins/silero/vision_agents/plugins/silero/vad.py index 20ced788..03a33189 100644 --- a/plugins/silero/vision_agents/plugins/silero/vad.py +++ b/plugins/silero/vision_agents/plugins/silero/vad.py @@ -375,19 +375,21 @@ async def is_speech(self, frame: PcmData) -> float: # Update current speech probability self._current_speech_probability = speech_prob - self.events.send(vad.events.VADInferenceEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - speech_probability=speech_prob, - inference_time_ms=inference_time, - window_samples=self.window_samples, - model_rate=self.model_rate, - real_time_factor=rtf, - is_speech_active=self.is_speech_active, - accumulated_speech_duration_ms=self._get_accumulated_speech_duration(), - accumulated_silence_duration_ms=self._get_accumulated_silence_duration(), - user_metadata=None, # Will be set by caller if needed - )) + self.events.send( + vad.events.VADInferenceEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + speech_probability=speech_prob, + inference_time_ms=inference_time, + window_samples=self.window_samples, + model_rate=self.model_rate, + real_time_factor=rtf, + is_speech_active=self.is_speech_active, + accumulated_speech_duration_ms=self._get_accumulated_speech_duration(), + accumulated_silence_duration_ms=self._get_accumulated_silence_duration(), + user_metadata=None, # Will be set by caller if needed + ) + ) # Log speech probability and RTF at DEBUG level logger.debug( @@ -414,7 +416,9 @@ async def is_speech(self, frame: PcmData) -> float: # On error, return low probability return 0.0 - async def _flush_speech_buffer(self, user: Optional[Union[Dict[str, Any], "Participant"]] = None) -> None: + async def _flush_speech_buffer( + self, user: Optional[Union[Dict[str, Any], "Participant"]] = None + ) -> None: """ Flush the accumulated speech buffer if it meets minimum length requirements. 
@@ -440,31 +444,35 @@ async def _flush_speech_buffer(self, user: Optional[Union[Dict[str, Any], "Parti # Calculate average speech probability during this segment avg_speech_prob = self._get_avg_speech_probability() - self.events.send(vad.events.VADAudioEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - audio_data=speech_data.tobytes(), - sample_rate=self.sample_rate, - audio_format=vad.events.AudioFormat.PCM_S16, - channels=1, - duration_ms=duration_ms, - speech_probability=avg_speech_prob, - frame_count=len(speech_data) // self.frame_size, - user_metadata=user, - )) + self.events.send( + vad.events.VADAudioEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + audio_data=speech_data.tobytes(), + sample_rate=self.sample_rate, + audio_format=vad.events.AudioFormat.PCM_S16, + channels=1, + duration_ms=duration_ms, + speech_probability=avg_speech_prob, + frame_count=len(speech_data) // self.frame_size, + user_metadata=user, + ) + ) # Emit speech end event if we were actively detecting speech if self.is_speech_active and self._speech_start_time: total_speech_duration = (time.time() - self._speech_start_time) * 1000 - self.events.send(vad.events.VADSpeechEndEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - speech_probability=self._speech_end_probability, - deactivation_threshold=self.deactivation_th, - total_speech_duration_ms=total_speech_duration, - total_frames=self.total_speech_frames, - user_metadata=user, - )) + self.events.send( + vad.events.VADSpeechEndEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + speech_probability=self._speech_end_probability, + deactivation_threshold=self.deactivation_th, + total_speech_duration_ms=total_speech_duration, + total_frames=self.total_speech_frames, + user_metadata=user, + ) + ) # Reset state variables self.speech_buffer = bytearray() @@ -530,15 +538,17 @@ async def _process_frame( self._speech_start_probability = speech_prob self._speech_probabilities = [speech_prob] # Reset probability tracking - self.events.send(VADSpeechStartEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - speech_probability=speech_prob, - activation_threshold=self.activation_th, - frame_count=1, - user_metadata=user, - audio_data=frame - )) + self.events.send( + VADSpeechStartEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + speech_probability=speech_prob, + activation_threshold=self.activation_th, + frame_count=1, + user_metadata=user, + audio_data=frame, + ) + ) # Add this frame to the buffer using shared utility from getstream.audio.pcm_utils import numpy_array_to_bytes @@ -568,19 +578,21 @@ async def _process_frame( # Calculate current duration current_duration_ms = (len(current_samples) / self.sample_rate) * 1000 - self.events.send(vad.events.VADPartialEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - audio_data=current_bytes, - sample_rate=self.sample_rate, - audio_format=AudioFormat.PCM_S16, - channels=1, - duration_ms=current_duration_ms, - speech_probability=speech_prob, - frame_count=len(current_samples) // self.frame_size, - is_speech_active=True, - user_metadata=user, - )) + self.events.send( + vad.events.VADPartialEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + audio_data=current_bytes, + sample_rate=self.sample_rate, + audio_format=AudioFormat.PCM_S16, + channels=1, + duration_ms=current_duration_ms, + speech_probability=speech_prob, + frame_count=len(current_samples) // self.frame_size, 
+ is_speech_active=True, + user_metadata=user, + ) + ) self.partial_counter = 0 diff --git a/plugins/smart_turn/tests/test_turn_detection.py b/plugins/smart_turn/tests/test_turn_detection.py index 5387e45d..32f9df1f 100644 --- a/plugins/smart_turn/tests/test_turn_detection.py +++ b/plugins/smart_turn/tests/test_turn_detection.py @@ -28,4 +28,3 @@ async def test_turn_detection_integration(self): # This test should be run manually with a valid FAL_KEY # For now, just pass assert True - diff --git a/plugins/smart_turn/vision_agents/plugins/smart_turn/__init__.py b/plugins/smart_turn/vision_agents/plugins/smart_turn/__init__.py index 5bb06826..f5b2690b 100644 --- a/plugins/smart_turn/vision_agents/plugins/smart_turn/__init__.py +++ b/plugins/smart_turn/vision_agents/plugins/smart_turn/__init__.py @@ -4,4 +4,3 @@ __path__ = __import__("pkgutil").extend_path(__path__, __name__) __all__ = ["TurnDetection"] - diff --git a/plugins/smart_turn/vision_agents/plugins/smart_turn/turn_detection.py b/plugins/smart_turn/vision_agents/plugins/smart_turn/turn_detection.py index 4736696d..36dbf3b6 100644 --- a/plugins/smart_turn/vision_agents/plugins/smart_turn/turn_detection.py +++ b/plugins/smart_turn/vision_agents/plugins/smart_turn/turn_detection.py @@ -62,7 +62,8 @@ def __init__( """ super().__init__( - confidence_threshold=confidence_threshold, provider_name="SmartTurnDetection" + confidence_threshold=confidence_threshold, + provider_name="SmartTurnDetection", ) self.logger = logging.getLogger("SmartTurnDetection") self.api_key = api_key @@ -96,9 +97,7 @@ def _infer_channels(self, format_str: str) -> int: elif any(f in format_str for f in ["mono", "s16", "int16", "pcm_s16le"]): return 1 else: - self.logger.warning( - f"Unknown format string: {format_str}. Assuming mono." - ) + self.logger.warning(f"Unknown format string: {format_str}. 
Assuming mono.") return 1 def is_detecting(self) -> bool: @@ -379,4 +378,3 @@ def stop(self) -> None: self.logger.warning(f"Failed to clean up temp files: {e}") self.logger.info("Smart Turn detection stopped") - diff --git a/plugins/ultralytics/tests/test_ultralytics.py b/plugins/ultralytics/tests/test_ultralytics.py index 756126d3..c627fd43 100644 --- a/plugins/ultralytics/tests/test_ultralytics.py +++ b/plugins/ultralytics/tests/test_ultralytics.py @@ -35,14 +35,18 @@ def pose_processor(self) -> Iterator[YOLOPoseProcessor]: finally: processor.close() - async def test_annotated_ndarray(self, golf_image: Image.Image, pose_processor: YOLOPoseProcessor): + async def test_annotated_ndarray( + self, golf_image: Image.Image, pose_processor: YOLOPoseProcessor + ): frame_array = np.array(golf_image) array_with_pose, pose = await pose_processor.add_pose_to_ndarray(frame_array) assert array_with_pose is not None assert pose is not None - async def test_annotated_image_output(self, golf_image: Image.Image, pose_processor: YOLOPoseProcessor): + async def test_annotated_image_output( + self, golf_image: Image.Image, pose_processor: YOLOPoseProcessor + ): image_with_pose, pose = await pose_processor.add_pose_to_image(image=golf_image) assert image_with_pose is not None @@ -50,37 +54,37 @@ async def test_annotated_image_output(self, golf_image: Image.Image, pose_proces # Ensure same size as input for simplicity assert image_with_pose.size == golf_image.size - + # Save the annotated image temporarily for inspection temp_path = Path("/tmp/annotated_golf_swing.png") image_with_pose.save(temp_path) print(f"Saved annotated image to: {temp_path}") - async def test_annotated_frame_output(self, golf_image: Image.Image, pose_processor: YOLOPoseProcessor): + async def test_annotated_frame_output( + self, golf_image: Image.Image, pose_processor: YOLOPoseProcessor + ): """Test add_pose_to_frame method with av.VideoFrame input.""" # Convert PIL Image to av.VideoFrame frame = av.VideoFrame.from_image(golf_image) - + # Process the frame with pose detection frame_with_pose = await pose_processor.add_pose_to_frame(frame) - + # Verify the result is an av.VideoFrame assert frame_with_pose is not None assert isinstance(frame_with_pose, av.VideoFrame) - + # Verify dimensions are preserved assert frame_with_pose.width == frame.width assert frame_with_pose.height == frame.height - + # Convert back to numpy array to verify it's been processed result_array = frame_with_pose.to_ndarray() original_array = frame.to_ndarray() - + # The arrays should be the same shape assert result_array.shape == original_array.shape - + # The processed frame should be different from the original (pose annotations added) # Note: This might not always be true if no pose is detected, but it's a reasonable check assert not np.array_equal(result_array, original_array) - - diff --git a/plugins/wizper/vision_agents/plugins/wizper/stt.py b/plugins/wizper/vision_agents/plugins/wizper/stt.py index 60ca6c79..4bb638bd 100644 --- a/plugins/wizper/vision_agents/plugins/wizper/stt.py +++ b/plugins/wizper/vision_agents/plugins/wizper/stt.py @@ -99,7 +99,9 @@ def _pcm_to_wav_bytes(self, pcm_data: PcmData) -> bytes: return wav_buffer.read() async def _process_audio_impl( - self, pcm_data: PcmData, user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None + self, + pcm_data: PcmData, + user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None, ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]: """ Process accumulated speech audio through 
fal-ai/wizper. diff --git a/plugins/xai/vision_agents/plugins/xai/events.py b/plugins/xai/vision_agents/plugins/xai/events.py index 07c9b4f1..d22ec7a9 100644 --- a/plugins/xai/vision_agents/plugins/xai/events.py +++ b/plugins/xai/vision_agents/plugins/xai/events.py @@ -6,5 +6,6 @@ @dataclass class XAIChunkEvent(PluginBaseEvent): """Event emitted when xAI provides a chunk.""" - type: str = field(default='plugin.xai.chunk', init=False) + + type: str = field(default="plugin.xai.chunk", init=False) chunk: Optional[Any] = None diff --git a/plugins/xai/vision_agents/plugins/xai/llm.py b/plugins/xai/vision_agents/plugins/xai/llm.py index 5392ab64..d7d55422 100644 --- a/plugins/xai/vision_agents/plugins/xai/llm.py +++ b/plugins/xai/vision_agents/plugins/xai/llm.py @@ -5,7 +5,10 @@ from vision_agents.core.llm.llm import LLM, LLMResponseEvent from vision_agents.core.processors import Processor -from vision_agents.core.llm.events import LLMResponseChunkEvent, LLMResponseCompletedEvent +from vision_agents.core.llm.events import ( + LLMResponseChunkEvent, + LLMResponseCompletedEvent, +) from . import events if TYPE_CHECKING: @@ -91,7 +94,9 @@ async def simple_response( instructions=instructions, ) - async def create_response(self, *args: Any, **kwargs: Any) -> LLMResponseEvent[Response]: + async def create_response( + self, *args: Any, **kwargs: Any + ) -> LLMResponseEvent[Response]: """ create_response gives you full support/access to the native xAI chat.sample() and chat.stream() methods this method wraps the xAI method and ensures we broadcast an event which the agent class hooks into @@ -139,10 +144,11 @@ async def create_response(self, *args: Any, **kwargs: Any) -> LLMResponseEvent[R self.xai_chat.append(response) if llm_response is not None: - self.events.send(LLMResponseCompletedEvent( - original=llm_response.original, - text=llm_response.text - )) + self.events.send( + LLMResponseCompletedEvent( + original=llm_response.original, text=llm_response.text + ) + ) return llm_response or LLMResponseEvent[Response]( Response(chat_pb2.GetChatCompletionResponse(), 0), "" @@ -170,31 +176,32 @@ def _standardize_and_emit_chunk( Forwards the chunk events and also send out a standardized version (the agent class hooks into that) """ # Emit the raw chunk event - self.events.send(events.XAIChunkEvent( - plugin_name="xai", - chunk=chunk - )) + self.events.send(events.XAIChunkEvent(plugin_name="xai", chunk=chunk)) # Emit standardized delta events for content if chunk.content: - self.events.send(LLMResponseChunkEvent( - content_index=0, # xAI doesn't have content_index - item_id=chunk.proto.id if hasattr(chunk.proto, "id") else "", - output_index=0, # xAI doesn't have output_index - sequence_number=0, # xAI doesn't have sequence_number - delta=chunk.content, - plugin_name="xai", - )) + self.events.send( + LLMResponseChunkEvent( + content_index=0, # xAI doesn't have content_index + item_id=chunk.proto.id if hasattr(chunk.proto, "id") else "", + output_index=0, # xAI doesn't have output_index + sequence_number=0, # xAI doesn't have sequence_number + delta=chunk.content, + plugin_name="xai", + ) + ) # Check if this is the final chunk (finish_reason indicates completion) if chunk.choices and chunk.choices[0].finish_reason: # This is the final chunk, return the complete response llm_response = LLMResponseEvent[Response](response, response.content) - self.events.send(LLMResponseCompletedEvent( - plugin_name="xai", - text=llm_response.text, - original=llm_response.original - )) + self.events.send( + 
LLMResponseCompletedEvent( + plugin_name="xai", + text=llm_response.text, + original=llm_response.original, + ) + ) return llm_response return None diff --git a/tests/base_test.py b/tests/base_test.py index 9faf2f79..c66364be 100644 --- a/tests/base_test.py +++ b/tests/base_test.py @@ -13,7 +13,7 @@ class BaseTest: def assets_dir(self): """Get the test assets directory path.""" return os.path.join(os.path.dirname(__file__), "test_assets") - + @pytest.fixture def mia_audio_16khz(self): audio_file_path = os.path.join(os.path.dirname(__file__), "test_assets/mia.mp3") @@ -27,11 +27,7 @@ def mia_audio_16khz(self): # Create resampler if needed resampler = None if original_sample_rate != target_rate: - resampler = av.AudioResampler( - format='s16', - layout='mono', - rate=target_rate - ) + resampler = av.AudioResampler(format="s16", layout="mono", rate=target_rate) # Read all audio frames samples = [] @@ -55,11 +51,7 @@ def mia_audio_16khz(self): container.close() # Create PCM data - pcm = PcmData( - samples=samples, - sample_rate=target_rate, - format="s16" - ) + pcm = PcmData(samples=samples, sample_rate=target_rate, format="s16") return pcm @@ -67,7 +59,10 @@ def mia_audio_16khz(self): def bunny_video_track(self): """Create RealVideoTrack from video file""" from aiortc import VideoStreamTrack - video_file_path = os.path.join(os.path.dirname(__file__), "test_assets/bunny_3s.mp4") + + video_file_path = os.path.join( + os.path.dirname(__file__), "test_assets/bunny_3s.mp4" + ) class RealVideoTrack(VideoStreamTrack): def __init__(self, video_path, max_frames=None): @@ -87,16 +82,16 @@ async def recv(self): for frame in self.container.decode(self.video_stream): if frame is None: raise asyncio.CancelledError("End of video stream") - + self.frame_count += 1 # Convert to RGB frame = frame.to_rgb() - + # Sleep for realistic video timing await asyncio.sleep(self.frame_duration) - + return frame - + # If we get here, we've exhausted all frames in the stream raise asyncio.CancelledError("End of video stream") @@ -111,4 +106,4 @@ async def recv(self): print(f"Error reading video frame: {e}") raise asyncio.CancelledError("Video read error") - return RealVideoTrack(video_file_path, max_frames=None) \ No newline at end of file + return RealVideoTrack(video_file_path, max_frames=None) diff --git a/tests/test_conversation.py b/tests/test_conversation.py index 2731a2c7..0e0c1d8f 100644 --- a/tests/test_conversation.py +++ b/tests/test_conversation.py @@ -16,28 +16,32 @@ Message, InMemoryConversation, # StreamConversation, # Removed from codebase - StreamHandle + StreamHandle, ) # Skip entire module - StreamConversation class has been removed from codebase # TODO: Update tests to use new conversation architecture -pytestmark = pytest.mark.skip(reason="StreamConversation class removed - tests need migration to new architecture") +pytestmark = pytest.mark.skip( + reason="StreamConversation class removed - tests need migration to new architecture" +) + class TestConversation: """Test suite for the abstract Conversation class.""" - + def test_conversation_is_abstract(self): """Test that Conversation cannot be instantiated directly.""" with pytest.raises(TypeError) as exc_info: Conversation("instructions", []) assert "Can't instantiate abstract class" in str(exc_info.value) - + def test_conversation_requires_abstract_methods(self): """Test that subclasses must implement abstract methods.""" + class IncompleteConversation(Conversation): # Missing implementation of abstract methods pass - + with pytest.raises(TypeError) 
as exc_info: IncompleteConversation("instructions", []) assert "Can't instantiate abstract class" in str(exc_info.value) @@ -45,16 +49,16 @@ class IncompleteConversation(Conversation): class TestMessage: """Test suite for the Message dataclass.""" - + def test_message_initialization(self): """Test that Message initializes correctly with default timestamp.""" message = Message( original={"role": "user", "content": "Hello"}, content="Hello", role="user", - user_id="test-user" + user_id="test-user", ) - + assert message.content == "Hello" assert message.role == "user" assert message.user_id == "test-user" @@ -64,69 +68,65 @@ def test_message_initialization(self): class TestInMemoryConversation: """Test suite for InMemoryConversation class.""" - + @pytest.fixture def conversation(self): """Create a basic InMemoryConversation instance.""" instructions = "You are a helpful assistant." messages = [ Message(original=None, content="Hello", role="user", user_id="user1"), - Message(original=None, content="Hi there!", role="assistant", user_id="assistant") + Message( + original=None, + content="Hi there!", + role="assistant", + user_id="assistant", + ), ] # Set IDs for messages for i, msg in enumerate(messages): msg.id = f"msg-{i}" return InMemoryConversation(instructions, messages) - + def test_initialization(self, conversation): """Test InMemoryConversation initialization.""" assert conversation.instructions == "You are a helpful assistant." assert len(conversation.messages) == 2 - + def test_add_message(self, conversation): """Test adding a single message.""" new_message = Message( - original=None, - content="New message", - role="user", - user_id="user2" + original=None, content="New message", role="user", user_id="user2" ) new_message.id = "new-msg" conversation.add_message(new_message) - + assert len(conversation.messages) == 3 assert conversation.messages[-1] == new_message - + def test_add_message_with_completed(self, conversation): """Test adding a message with completed parameter.""" # Test with completed=False new_message1 = Message( - original=None, - content="Generating message", - role="user", - user_id="user2" + original=None, content="Generating message", role="user", user_id="user2" ) new_message1.id = "gen-msg" result = conversation.add_message(new_message1, completed=False) - + assert len(conversation.messages) == 3 assert conversation.messages[-1] == new_message1 assert result is None # InMemoryConversation returns None - + # Test with completed=True (default) new_message2 = Message( - original=None, - content="Complete message", - role="user", - user_id="user3" + original=None, content="Complete message", role="user", user_id="user3" ) new_message2.id = "comp-msg" result = conversation.add_message(new_message2, completed=True) - + assert len(conversation.messages) == 4 assert conversation.messages[-1] == new_message2 assert result is None - + def test_update_message_existing(self, conversation): """Test updating an existing message by appending content.""" # Update existing message by appending (replace_content=False) @@ -135,13 +135,13 @@ def test_update_message_existing(self, conversation): input_text=" additional text", user_id="user1", replace_content=False, - completed=False + completed=False, ) - + # Verify message content was appended (with space handling) assert conversation.messages[0].content == "Hello additional text" assert result is None # InMemoryConversation returns None - + def test_update_message_replace(self, conversation): """Test replacing message content 
(replace_content=True).""" result = conversation.update_message( @@ -149,106 +149,116 @@ def test_update_message_replace(self, conversation): input_text="Replaced content", user_id="user1", replace_content=True, - completed=True + completed=True, ) - + # Verify message content was replaced assert conversation.messages[0].content == "Replaced content" assert result is None - + def test_update_message_not_found(self, conversation): """Test updating a non-existent message creates a new one.""" initial_count = len(conversation.messages) - + conversation.update_message( message_id="non-existent-id", input_text="New message content", user_id="user2", replace_content=True, - completed=False + completed=False, ) - + # Should have added a new message assert len(conversation.messages) == initial_count + 1 - + # Verify the new message was created correctly new_msg = conversation.messages[-1] assert new_msg.id == "non-existent-id" assert new_msg.content == "New message content" assert new_msg.user_id == "user2" - + def test_streaming_message_handle(self, conversation): """Test streaming message with handle API.""" # Start a streaming message - handle = conversation.start_streaming_message(role="assistant", initial_content="Hello") - + handle = conversation.start_streaming_message( + role="assistant", initial_content="Hello" + ) + # Verify message was added assert len(conversation.messages) == 3 assert conversation.messages[-1].content == "Hello" assert conversation.messages[-1].role == "assistant" assert isinstance(handle, StreamHandle) assert handle.user_id == "assistant" - + # Append to the message conversation.append_to_message(handle, " world") assert conversation.messages[-1].content == "Hello world" - + # Replace the message conversation.replace_message(handle, "Goodbye") assert conversation.messages[-1].content == "Goodbye" - + # Complete the message conversation.complete_message(handle) # In-memory conversation doesn't track completed state, just verify no error - + def test_multiple_streaming_handles(self, conversation): """Test multiple concurrent streaming messages.""" # Start two streaming messages - handle1 = conversation.start_streaming_message(role="user", user_id="user1", initial_content="Question: ") - handle2 = conversation.start_streaming_message(role="assistant", initial_content="Answer: ") - + handle1 = conversation.start_streaming_message( + role="user", user_id="user1", initial_content="Question: " + ) + handle2 = conversation.start_streaming_message( + role="assistant", initial_content="Answer: " + ) + assert len(conversation.messages) == 4 # 2 initial + 2 new - + # Update them independently conversation.append_to_message(handle1, "What is 2+2?") conversation.append_to_message(handle2, "Let me calculate...") - + # Find messages by their handles to verify correct updates - msg1 = next(msg for msg in conversation.messages if msg.id == handle1.message_id) - msg2 = next(msg for msg in conversation.messages if msg.id == handle2.message_id) - + msg1 = next( + msg for msg in conversation.messages if msg.id == handle1.message_id + ) + msg2 = next( + msg for msg in conversation.messages if msg.id == handle2.message_id + ) + assert msg1.content == "Question: What is 2+2?" assert msg2.content == "Answer: Let me calculate..." 
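# A minimal sketch (assumed shape, not the library's actual implementation)
# of the handle mechanics the asserts above rely on: the handle only carries
# identifiers, so two concurrent streams stay independent as long as their
# message_ids differ.
from dataclasses import dataclass

@dataclass
class SketchStreamHandle:
    message_id: str
    user_id: str

def sketch_append(messages, handle: SketchStreamHandle, delta: str) -> None:
    # Locate the target message by id and mutate it in place; no other
    # message in the conversation is touched.
    msg = next(m for m in messages if m.id == handle.message_id)
    msg.content += delta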
- + # Complete them conversation.complete_message(handle1) conversation.replace_message(handle2, "Answer: 4") conversation.complete_message(handle2) - + assert msg2.content == "Answer: 4" # Replaced content, no space issue class TestStreamConversation: """Test suite for StreamConversation class.""" - + @pytest.fixture def mock_chat_client(self): """Create a mock ChatClient.""" client = Mock(spec=ChatClient) - + # Mock send_message response mock_response = Mock() mock_response.data.message.id = "stream-message-123" client.send_message.return_value = mock_response - + # Mock ephemeral_message_update client.ephemeral_message_update = Mock(return_value=Mock()) - + # Mock update_message_partial client.update_message_partial = Mock(return_value=Mock()) - + return client - + @pytest.fixture def mock_channel(self): """Create a mock ChannelResponse.""" @@ -256,7 +266,7 @@ def mock_channel(self): channel.type = "messaging" channel.id = "test-channel-123" return channel - + @pytest.fixture def stream_conversation(self, mock_chat_client, mock_channel): """Create a StreamConversation instance with mocked dependencies.""" @@ -272,119 +282,123 @@ def stream_conversation(self, mock_chat_client, mock_channel): # Set IDs for messages for i, msg in enumerate(messages): msg.id = f"msg-{i}" - + conversation = StreamConversation( # noqa: F821 instructions=instructions, messages=messages, channel=mock_channel, - chat_client=mock_chat_client + chat_client=mock_chat_client, ) - + # Pre-populate some stream IDs for testing - conversation.internal_ids_to_stream_ids = { - "msg-0": "stream-msg-0" - } - + conversation.internal_ids_to_stream_ids = {"msg-0": "stream-msg-0"} + yield conversation - + # Cleanup after each test conversation.shutdown() - + def test_initialization(self, stream_conversation, mock_channel, mock_chat_client): """Test StreamConversation initialization.""" assert stream_conversation.channel == mock_channel assert stream_conversation.chat_client == mock_chat_client assert isinstance(stream_conversation.internal_ids_to_stream_ids, dict) assert len(stream_conversation.messages) == 1 - + def test_add_message(self, stream_conversation, mock_chat_client): """Test adding a message to the stream with default completed=True.""" new_message = Message( - original=None, - content="Test message", - role="user", - user_id="user123" + original=None, content="Test message", role="user", user_id="user123" ) new_message.id = "new-msg-id" - + stream_conversation.add_message(new_message) - + # Verify message was added locally immediately assert len(stream_conversation.messages) == 2 assert stream_conversation.messages[-1] == new_message - + # Wait for async operations to complete assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + # Verify Stream API was called mock_chat_client.send_message.assert_called_once() call_args = mock_chat_client.send_message.call_args assert call_args[0][0] == "messaging" # channel type assert call_args[0][1] == "test-channel-123" # channel id - + request = call_args[0][2] assert isinstance(request, MessageRequest) assert request.text == "Test message" assert request.user_id == "user123" - + # Verify ID mapping was stored assert "new-msg-id" in stream_conversation.internal_ids_to_stream_ids - assert stream_conversation.internal_ids_to_stream_ids["new-msg-id"] == "stream-message-123" - + assert ( + stream_conversation.internal_ids_to_stream_ids["new-msg-id"] + == "stream-message-123" + ) + # Wait a bit more for the update operation to complete time.sleep(0.1) - + # Verify 
update_message_partial was called (completed=True is default) mock_chat_client.update_message_partial.assert_called_once() update_args = mock_chat_client.update_message_partial.call_args assert update_args[0][0] == "stream-message-123" assert update_args[1]["user_id"] == "user123" assert update_args[1]["set"]["text"] == "Test message" - assert update_args[1]["set"]["generating"] is False # completed=True means not generating - - def test_add_message_with_completed_false(self, stream_conversation, mock_chat_client): + assert ( + update_args[1]["set"]["generating"] is False + ) # completed=True means not generating + + def test_add_message_with_completed_false( + self, stream_conversation, mock_chat_client + ): """Test adding a message with completed=False (still generating).""" # Ensure previous operations are complete stream_conversation.wait_for_pending_operations(timeout=1.0) - + # Reset mocks mock_chat_client.send_message.reset_mock() mock_chat_client.ephemeral_message_update.reset_mock() mock_chat_client.update_message_partial.reset_mock() - + new_message = Message( original=None, content="Generating message", role="assistant", - user_id="assistant" + user_id="assistant", ) new_message.id = "gen-msg-id" - + stream_conversation.add_message(new_message, completed=False) - + # Verify message was added locally assert len(stream_conversation.messages) == 2 assert stream_conversation.messages[-1] == new_message - + # Wait for async operations to complete assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + # Verify Stream API was called mock_chat_client.send_message.assert_called_once() - + # Give a bit more time for the update operation to be queued and processed time.sleep(0.2) - + # Verify ephemeral_message_update was called (completed=False) mock_chat_client.ephemeral_message_update.assert_called_once() mock_chat_client.update_message_partial.assert_not_called() - + update_args = mock_chat_client.ephemeral_message_update.call_args assert update_args[0][0] == "stream-message-123" assert update_args[1]["user_id"] == "assistant" assert update_args[1]["set"]["text"] == "Generating message" - assert update_args[1]["set"]["generating"] is True # completed=False means still generating - + assert ( + update_args[1]["set"]["generating"] is True + ) # completed=False means still generating + def test_update_message_existing(self, stream_conversation, mock_chat_client): """Test updating an existing message by appending content.""" # Update existing message by appending (replace_content=False, completed=False) @@ -393,42 +407,44 @@ def test_update_message_existing(self, stream_conversation, mock_chat_client): input_text=" additional text", user_id="user1", replace_content=False, - completed=False + completed=False, ) - + # Verify message content was appended immediately assert stream_conversation.messages[0].content == "Hello additional text" - + # Wait for async operations to complete assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + # Verify Stream API was called with ephemeral_message_update (not completed) mock_chat_client.ephemeral_message_update.assert_called_once() call_args = mock_chat_client.ephemeral_message_update.call_args assert call_args[0][0] == "stream-msg-0" # stream message ID assert call_args[1]["user_id"] == "user1" assert call_args[1]["set"]["text"] == "Hello additional text" - assert call_args[1]["set"]["generating"] is True # not completed = still generating - + assert ( + call_args[1]["set"]["generating"] is True + ) # not completed = 
still generating + def test_update_message_replace(self, stream_conversation, mock_chat_client): """Test replacing message content (replace_content=True).""" # Mock update_message_partial for completed messages mock_chat_client.update_message_partial = Mock(return_value=Mock()) - + stream_conversation.update_message( message_id="msg-0", input_text="Replaced content", user_id="user1", replace_content=True, - completed=True + completed=True, ) - + # Verify message content was replaced assert stream_conversation.messages[0].content == "Replaced content" - + # Wait for async operations to complete assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + # Verify Stream API was called with update_message_partial (completed) mock_chat_client.update_message_partial.assert_called_once() call_args = mock_chat_client.update_message_partial.call_args @@ -436,269 +452,281 @@ def test_update_message_replace(self, stream_conversation, mock_chat_client): assert call_args[1]["user_id"] == "user1" assert call_args[1]["set"]["text"] == "Replaced content" assert call_args[1]["set"]["generating"] is False # completed = not generating - + def test_update_message_not_found(self, stream_conversation, mock_chat_client): """Test updating a non-existent message creates a new one.""" # Reset the send_message mock for this test mock_chat_client.send_message.reset_mock() - + stream_conversation.update_message( message_id="non-existent-id", input_text="New message content", user_id="user2", replace_content=True, - completed=False + completed=False, ) - + # Should have added a new message assert len(stream_conversation.messages) == 2 - + # Verify the new message was created correctly new_msg = stream_conversation.messages[-1] assert new_msg.id == "non-existent-id" assert new_msg.content == "New message content" assert new_msg.user_id == "user2" - + # Wait for async operations to complete assert stream_conversation.wait_for_pending_operations(timeout=2.0) time.sleep(0.2) # Give extra time for update operation - + # Verify send_message was called (not update) mock_chat_client.send_message.assert_called_once() - - def test_update_message_completed_vs_generating(self, stream_conversation, mock_chat_client): + + def test_update_message_completed_vs_generating( + self, stream_conversation, mock_chat_client + ): """Test that completed=True calls update_message_partial and completed=False calls ephemeral_message_update.""" # Mock update_message_partial for completed messages mock_chat_client.update_message_partial = Mock(return_value=Mock()) - + # Test with completed=False (still generating) stream_conversation.update_message( message_id="msg-0", input_text=" in progress", user_id="user1", replace_content=False, - completed=False + completed=False, ) - + # Wait for async operations assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + # Should call ephemeral_message_update mock_chat_client.ephemeral_message_update.assert_called() mock_chat_client.update_message_partial.assert_not_called() - + # Reset mocks mock_chat_client.ephemeral_message_update.reset_mock() - + # Test with completed=True stream_conversation.update_message( message_id="msg-0", input_text="Final content", user_id="user1", replace_content=True, - completed=True + completed=True, ) - + # Wait for async operations assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + # Should call update_message_partial mock_chat_client.update_message_partial.assert_called_once() 
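# A minimal sketch (assumed logic, inferred from the two asserts around this
# point) of the routing under test: a completed update goes through the
# durable update_message_partial endpoint, while an in-progress update uses
# the ephemeral endpoint and keeps generating=True.
def sketch_sync_update(chat_client, stream_id, user_id, text, completed):
    payload = {"text": text, "generating": not completed}
    if completed:
        return chat_client.update_message_partial(stream_id, user_id=user_id, set=payload)
    return chat_client.ephemeral_message_update(stream_id, user_id=user_id, set=payload)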
mock_chat_client.ephemeral_message_update.assert_not_called() - + def test_update_message_no_stream_id(self, stream_conversation, mock_chat_client): """Test updating a message without a stream ID mapping.""" # Add a message without stream ID mapping - new_msg = Message( - original=None, - content="Test", - role="user", - user_id="user3" - ) + new_msg = Message(original=None, content="Test", role="user", user_id="user3") new_msg.id = "unmapped-msg" stream_conversation.messages.append(new_msg) - + # Try to update it by appending stream_conversation.update_message( message_id="unmapped-msg", input_text=" updated", user_id="user3", replace_content=False, - completed=False + completed=False, ) - + # Message should still be updated locally (with space handling) assert stream_conversation.messages[-1].content == "Test updated" - + # Since there's no stream_id mapping, the API call should be skipped # This is the expected behavior - we don't sync messages without stream IDs mock_chat_client.ephemeral_message_update.assert_not_called() - + def test_streaming_message_handle(self, stream_conversation, mock_chat_client): """Test streaming message with handle API.""" # Reset mocks mock_chat_client.send_message.reset_mock() mock_chat_client.ephemeral_message_update.reset_mock() mock_chat_client.update_message_partial.reset_mock() - + # Start a streaming message - handle = stream_conversation.start_streaming_message(role="assistant", initial_content="Processing") - + handle = stream_conversation.start_streaming_message( + role="assistant", initial_content="Processing" + ) + # Verify message was added and marked as generating assert len(stream_conversation.messages) == 2 assert stream_conversation.messages[-1].content == "Processing" assert stream_conversation.messages[-1].role == "assistant" assert isinstance(handle, StreamHandle) assert handle.user_id == "assistant" - + # Wait for async operations assert stream_conversation.wait_for_pending_operations(timeout=2.0) time.sleep(0.2) # Give extra time for update operation - + # Verify send_message was called mock_chat_client.send_message.assert_called_once() # Verify ephemeral_message_update was called (completed=False by default) mock_chat_client.ephemeral_message_update.assert_called_once() - + # Reset for next operations mock_chat_client.ephemeral_message_update.reset_mock() - + # Append to the message stream_conversation.append_to_message(handle, "...") assert stream_conversation.messages[-1].content == "Processing..." 
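# A minimal sketch (assumed; _work_queue is a hypothetical name) of the
# local-first pattern these waits depend on: the in-memory message mutates
# synchronously, while the API sync is queued for the worker thread — which
# is why the tests call wait_for_pending_operations() before asserting on
# the mocked client.
def sketch_append_to_message(conversation, handle, delta: str) -> None:
    msg = next(m for m in conversation.messages if m.id == handle.message_id)
    msg.content += delta  # visible to callers immediately
    # Hypothetical queue attribute; the real worker plumbing may differ.
    conversation._work_queue.put(("ephemeral_update", handle.message_id, msg.content))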
- + # Wait for append operation to complete assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + mock_chat_client.ephemeral_message_update.assert_called_once() - + # Replace the message stream_conversation.replace_message(handle, "Complete response") assert stream_conversation.messages[-1].content == "Complete response" - + # Wait for replace operation to complete assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + assert mock_chat_client.ephemeral_message_update.call_count == 2 - + # Complete the message mock_chat_client.update_message_partial.reset_mock() stream_conversation.complete_message(handle) - + # Wait for complete operation assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + mock_chat_client.update_message_partial.assert_called_once() - + def test_multiple_streaming_handles(self, stream_conversation, mock_chat_client): """Test multiple concurrent streaming messages with Stream API.""" # Reset mocks mock_chat_client.send_message.reset_mock() mock_chat_client.ephemeral_message_update.reset_mock() - + # Mock different message IDs for each send mock_response1 = Mock() mock_response1.data.message.id = "stream-msg-1" mock_response2 = Mock() mock_response2.data.message.id = "stream-msg-2" mock_chat_client.send_message.side_effect = [mock_response1, mock_response2] - + # Start two streaming messages with empty initial content - handle1 = stream_conversation.start_streaming_message(role="user", user_id="user123", initial_content="") - handle2 = stream_conversation.start_streaming_message(role="assistant", initial_content="") - + handle1 = stream_conversation.start_streaming_message( + role="user", user_id="user123", initial_content="" + ) + handle2 = stream_conversation.start_streaming_message( + role="assistant", initial_content="" + ) + assert len(stream_conversation.messages) == 3 # 1 initial + 2 new - + # Wait for initial operations to complete assert stream_conversation.wait_for_pending_operations(timeout=2.0) time.sleep(0.3) # Give extra time for update operations - + # Update them independently stream_conversation.append_to_message(handle1, "Hello?") stream_conversation.append_to_message(handle2, "Hi there!") - + # Wait for append operations to complete assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + # Find messages by their handles to verify correct updates - msg1 = next(msg for msg in stream_conversation.messages if msg.id == handle1.message_id) - msg2 = next(msg for msg in stream_conversation.messages if msg.id == handle2.message_id) - + msg1 = next( + msg for msg in stream_conversation.messages if msg.id == handle1.message_id + ) + msg2 = next( + msg for msg in stream_conversation.messages if msg.id == handle2.message_id + ) + assert msg1.content == "Hello?" assert msg2.content == "Hi there!" 
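# Why the mock must return distinct ids (a sketch of the assumed flow): each
# send_message response carries the server-side message id, which is recorded
# in internal_ids_to_stream_ids so that later updates for a given handle
# target the right remote message.
def sketch_record_stream_id(conversation, internal_id: str, response) -> None:
    stream_id = response.data.message.id
    conversation.internal_ids_to_stream_ids[internal_id] = stream_id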
- + # Verify ephemeral updates were called for both - assert mock_chat_client.ephemeral_message_update.call_count >= 4 # 2 initial + 2 appends - + assert ( + mock_chat_client.ephemeral_message_update.call_count >= 4 + ) # 2 initial + 2 appends + # Complete both stream_conversation.complete_message(handle1) stream_conversation.complete_message(handle2) - + # Wait for completion operations assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + # Verify update_message_partial was called for both completions assert mock_chat_client.update_message_partial.call_count == 2 - - def test_worker_thread_async_operations(self, stream_conversation, mock_chat_client): + + def test_worker_thread_async_operations( + self, stream_conversation, mock_chat_client + ): """Test that operations are processed asynchronously by the worker thread.""" # Reset mocks mock_chat_client.send_message.reset_mock() mock_chat_client.ephemeral_message_update.reset_mock() - + # Add multiple messages quickly messages = [] for i in range(5): msg = Message( - original=None, - content=f"Message {i}", - role="user", - user_id=f"user{i}" + original=None, content=f"Message {i}", role="user", user_id=f"user{i}" ) messages.append(msg) stream_conversation.add_message(msg, completed=False) - + # Verify messages were added locally immediately assert len(stream_conversation.messages) == 6 # 1 initial + 5 new - + # Wait for all operations to complete assert stream_conversation.wait_for_pending_operations(timeout=3.0) - + # Give a bit more time for update operations time.sleep(0.5) - + # Verify all send_message calls were made assert mock_chat_client.send_message.call_count == 5 - + # Verify all ephemeral_message_update calls were made assert mock_chat_client.ephemeral_message_update.call_count >= 5 - - def test_wait_for_pending_operations_timeout(self, stream_conversation, mock_chat_client): + + def test_wait_for_pending_operations_timeout( + self, stream_conversation, mock_chat_client + ): """Test that wait_for_pending_operations returns False on timeout.""" # Make send_message block for a long time block_event = threading.Event() - + def slow_send_message(*args, **kwargs): block_event.wait(timeout=5.0) # Block for 5 seconds mock_response = Mock() mock_response.data.message.id = "stream-message-slow" return mock_response - + mock_chat_client.send_message.side_effect = slow_send_message - + # Add a message - msg = Message(original=None, content="Slow message", role="user", user_id="user1") + msg = Message( + original=None, content="Slow message", role="user", user_id="user1" + ) stream_conversation.add_message(msg) - + # Wait should timeout assert not stream_conversation.wait_for_pending_operations(timeout=0.5) - + # Unblock the operation block_event.set() - + # Now wait should succeed assert stream_conversation.wait_for_pending_operations(timeout=2.0) - + def test_shutdown_worker_thread(self, mock_chat_client, mock_channel): """Test that shutdown properly stops the worker thread.""" # Create a fresh conversation without using the fixture to avoid double shutdown @@ -706,18 +734,18 @@ def test_shutdown_worker_thread(self, mock_chat_client, mock_channel): instructions="Test", messages=[], channel=mock_channel, - chat_client=mock_chat_client + chat_client=mock_chat_client, ) - + # Verify thread is alive assert conversation._worker_thread.is_alive() - + # Shutdown conversation.shutdown() - + # Verify thread stopped assert not conversation._worker_thread.is_alive() - + # Verify shutdown flag is set assert conversation._shutdown 
is True @@ -726,20 +754,20 @@ def test_shutdown_worker_thread(self, mock_chat_client, mock_channel): def mock_stream_client(): """Create a mock Stream client for testing.""" from getstream import Stream - + client = Mock(spec=Stream) - + # Mock user creation mock_user = Mock() mock_user.id = "test-agent-user" mock_user.name = "Test Agent" client.create_user.return_value = mock_user - + # Mock video.call mock_call = Mock() mock_call.id = "test-call-123" client.video.call.return_value = mock_call - + return client @@ -751,28 +779,27 @@ def test_stream_conversation_integration(): if not os.getenv("STREAM_API_KEY"): pytest.skip("Stream credentials not available") - + # Create real client client = Stream.from_env() - + # Create a test channel and user user = client.create_user(id="test-user") - channel = client.chat.get_or_create_channel("messaging", str(uuid.uuid4()), data=ChannelInput(created_by_id=user.id)).data.channel + channel = client.chat.get_or_create_channel( + "messaging", str(uuid.uuid4()), data=ChannelInput(created_by_id=user.id) + ).data.channel # Create conversation conversation = StreamConversation( # noqa: F821 instructions="Test assistant", messages=[], channel=channel, - chat_client=client.chat + chat_client=client.chat, ) # Add a message message = Message( - original=None, - content="Hello from test", - role="user", - user_id=user.id + original=None, content="Hello from test", role="user", user_id=user.id ) conversation.add_message(message) @@ -784,58 +811,79 @@ def test_stream_conversation_integration(): assert message.id in conversation.internal_ids_to_stream_ids # update message with replace - conversation.update_message(message_id=message.id, input_text="Replaced content", user_id=user.id, replace_content=True, completed=True) + conversation.update_message( + message_id=message.id, + input_text="Replaced content", + user_id=user.id, + replace_content=True, + completed=True, + ) assert conversation.wait_for_pending_operations(timeout=5.0) - channel_data = client.chat.get_or_create_channel("messaging", channel.id, state=True).data + channel_data = client.chat.get_or_create_channel( + "messaging", channel.id, state=True + ).data assert len(channel_data.messages) == 1 assert channel_data.messages[0].text == "Replaced content" # Note: generating flag might not be in custom field depending on Stream API version # update message with delta - conversation.update_message(message_id=message.id, input_text=" more stuff", user_id=user.id, - replace_content=False, completed=True) + conversation.update_message( + message_id=message.id, + input_text=" more stuff", + user_id=user.id, + replace_content=False, + completed=True, + ) assert conversation.wait_for_pending_operations(timeout=5.0) - channel_data = client.chat.get_or_create_channel("messaging", channel.id, state=True).data + channel_data = client.chat.get_or_create_channel( + "messaging", channel.id, state=True + ).data assert len(channel_data.messages) == 1 assert channel_data.messages[0].text == "Replaced content more stuff" # Note: generating flag might not be in custom field depending on Stream API version - + # Test add_message with completed=False message2 = Message( original=None, content="Still generating...", role="assistant", - user_id="assistant" + user_id="assistant", ) conversation.add_message(message2, completed=False) assert conversation.wait_for_pending_operations(timeout=5.0) time.sleep(0.2) # Give extra time for update operation - - channel_data = client.chat.get_or_create_channel("messaging", channel.id, 
state=True).data + + channel_data = client.chat.get_or_create_channel( + "messaging", channel.id, state=True + ).data assert len(channel_data.messages) == 2 assert channel_data.messages[1].text == "Still generating..." # Note: generating flag might not be in custom field depending on Stream API version - + # Test streaming handle API - handle = conversation.start_streaming_message(role="assistant", initial_content="Thinking") + handle = conversation.start_streaming_message( + role="assistant", initial_content="Thinking" + ) assert conversation.wait_for_pending_operations(timeout=5.0) time.sleep(0.2) # Give extra time for update operation - + conversation.append_to_message(handle, "...") assert conversation.wait_for_pending_operations(timeout=5.0) - + conversation.replace_message(handle, "The answer is 42") assert conversation.wait_for_pending_operations(timeout=5.0) - + conversation.complete_message(handle) assert conversation.wait_for_pending_operations(timeout=5.0) - - channel_data = client.chat.get_or_create_channel("messaging", channel.id, state=True).data + + channel_data = client.chat.get_or_create_channel( + "messaging", channel.id, state=True + ).data assert len(channel_data.messages) == 3 assert channel_data.messages[2].text == "The answer is 42" # Note: generating flag might not be in custom field depending on Stream API version - + # Cleanup conversation.shutdown() diff --git a/tests/test_events.py b/tests/test_events.py index 192cdff7..b2822e27 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -58,6 +58,7 @@ async def my_handler(event: ValidEvent): await manager.wait() assert my_handler.value == 2 + @pytest.mark.asyncio async def test_subscribe_with_multiple_events_different(): manager = EventManager() @@ -65,6 +66,7 @@ async def test_subscribe_with_multiple_events_different(): manager.register(AnotherEvent) with pytest.raises(RuntimeError): + @manager.subscribe async def multi_event_handler(event1: ValidEvent, event2: AnotherEvent): pass @@ -76,6 +78,7 @@ async def test_subscribe_with_multiple_events_as_one_processes(): manager.register(ValidEvent) manager.register(AnotherEvent) value = 0 + @manager.subscribe async def multi_event_handler(event: ValidEvent | AnotherEvent): nonlocal value @@ -93,6 +96,7 @@ async def test_subscribe_unregistered_event_raises_key_error(): manager = EventManager(ignore_unknown_events=False) with pytest.raises(KeyError): + @manager.subscribe async def unknown_handler(event: ValidEvent): pass @@ -126,7 +130,6 @@ async def exception_handler(event: ExceptionEvent): assert recursive_counter["count"] == 2 - @pytest.mark.asyncio async def test_send_unknown_event_type_raises_key_error(): manager = EventManager(ignore_unknown_events=False) @@ -148,68 +151,70 @@ async def test_merge_managers_events_processed_in_one(): # Create two separate managers manager1 = EventManager() manager2 = EventManager() - + # Register different events in each manager manager1.register(ValidEvent) manager2.register(AnotherEvent) - + # Set up handlers in each manager all_events_processed = [] - + @manager1.subscribe async def manager1_handler(event: ValidEvent): all_events_processed.append(("manager1", event)) - + @manager2.subscribe async def manager2_handler(event: AnotherEvent): all_events_processed.append(("manager2", event)) - + # Send events to both managers before merging manager1.send(ValidEvent(field=1)) manager2.send(AnotherEvent(value="test")) - + # Wait for events to be processed in their respective managers await manager1.wait() await manager2.wait() - + 
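# A hedged sketch of what merge() is assumed to do, inferred from the
# assertions in this test and the silent-events test below: the merged
# manager hands its queue and silent-event set to the survivor and stops its
# own processing task. Attribute names other than _events, _silent_events,
# and _processing_task are hypothetical.
def sketch_merge(primary, other) -> None:
    primary._events.update(other._events)
    primary._silent_events |= other._silent_events
    other._silent_events = primary._silent_events  # shared set, per the is-check below
    other._queue = primary._queue  # sends from either manager land in one queue
    if other._processing_task is not None:
        other._processing_task.cancel()
        other._processing_task = None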
# Verify events were processed in their original managers assert len(all_events_processed) == 2 assert all_events_processed[0][0] == "manager1" assert all_events_processed[0][1].field == 1 assert all_events_processed[1][0] == "manager2" assert all_events_processed[1][1].value == "test" - + # Clear the processed events list all_events_processed.clear() - + # Merge manager2 into manager1 manager1.merge(manager2) - + # Verify that manager2's processing task is stopped assert manager2._processing_task is None - + # Send new events to both managers after merging manager1.send(ValidEvent(field=2)) manager2.send(AnotherEvent(value="merged")) - + # Wait for events to be processed (only manager1's task should be running) await manager1.wait() - + # After merging, both events should be processed by manager1's task # (manager2's processing task should be stopped) assert len(all_events_processed) == 2 # Both events should be processed by manager1's task assert all_events_processed[0][0] == "manager1" # ValidEvent assert all_events_processed[0][1].field == 2 - assert all_events_processed[1][0] == "manager2" # AnotherEvent (handler from manager2) + assert ( + all_events_processed[1][0] == "manager2" + ) # AnotherEvent (handler from manager2) assert all_events_processed[1][1].value == "merged" - + # Verify that manager2 can still send events but they go to manager1's queue # and are processed by manager1's task all_events_processed.clear() manager2.send(AnotherEvent(value="from_manager2")) await manager1.wait() - + # The event from manager2 should be processed by manager1's task assert len(all_events_processed) == 1 assert all_events_processed[0][0] == "manager2" # Handler from manager2 @@ -220,51 +225,51 @@ async def manager2_handler(event: AnotherEvent): async def test_merge_managers_preserves_silent_events(caplog): """Test that when two managers are merged, silent events from both are preserved.""" import logging - + manager1 = EventManager() manager2 = EventManager() - + manager1.register(ValidEvent) manager2.register(AnotherEvent) - + # Mark ValidEvent as silent in manager1 manager1.silent(ValidEvent) # Mark AnotherEvent as silent in manager2 manager2.silent(AnotherEvent) - + handler_called = [] - + @manager1.subscribe async def valid_handler(event: ValidEvent): handler_called.append("valid") - + @manager2.subscribe async def another_handler(event: AnotherEvent): handler_called.append("another") - + # Merge manager2 into manager1 manager1.merge(manager2) - + # Verify that both silent events are preserved assert "custom.validevent" in manager1._silent_events assert "custom.anotherevent" in manager1._silent_events - + # Verify that manager2 also references the merged silent events assert manager2._silent_events is manager1._silent_events - + # Capture logs at INFO level with caplog.at_level(logging.INFO): # Send both events manager1.send(ValidEvent(field=42)) manager1.send(AnotherEvent(value="test")) await manager1.wait() - + # Both handlers should have been called assert handler_called == ["valid", "another"] - + # Check log messages log_messages = [record.message for record in caplog.records] - + # Should NOT see "Called handler" for either event (both are silent) assert not any("Called handler valid_handler" in msg for msg in log_messages) assert not any("Called handler another_handler" in msg for msg in log_messages) @@ -274,42 +279,48 @@ async def another_handler(event: AnotherEvent): async def test_silent_suppresses_handler_logging(caplog): """Test that marking an event as silent suppresses the 
'Called handler' log message.""" import logging - + manager = EventManager() manager.register(ValidEvent) manager.register(AnotherEvent) - + handler_called = [] - + @manager.subscribe async def valid_handler(event: ValidEvent): handler_called.append("valid") - + @manager.subscribe async def another_handler(event: AnotherEvent): handler_called.append("another") - + # Mark ValidEvent as silent manager.silent(ValidEvent) - + # Capture logs at INFO level with caplog.at_level(logging.INFO): # Send both events manager.send(ValidEvent(field=42)) manager.send(AnotherEvent(value="test")) await manager.wait() - + # Both handlers should have been called assert handler_called == ["valid", "another"] - + # Check log messages log_messages = [record.message for record in caplog.records] - + # Should NOT see "Called handler" for ValidEvent (it's silent) - assert not any("Called handler valid_handler" in msg and "custom.validevent" in msg for msg in log_messages) - + assert not any( + "Called handler valid_handler" in msg and "custom.validevent" in msg + for msg in log_messages + ) + # SHOULD see "Called handler" for AnotherEvent (not silent) - assert any("Called handler another_handler" in msg and "custom.anotherevent" in msg for msg in log_messages) + assert any( + "Called handler another_handler" in msg and "custom.anotherevent" in msg + for msg in log_messages + ) @pytest.mark.asyncio @@ -317,76 +328,78 @@ async def another_handler(event: AnotherEvent): async def test_protobuf_events_with_base_event(): """Test that event manager handles protobuf events that inherit from BaseEvent.""" from vision_agents.core.events.manager import EventManager - from vision_agents.core.edge.sfu_events import AudioLevelEvent, ParticipantJoinedEvent + from vision_agents.core.edge.sfu_events import ( + AudioLevelEvent, + ParticipantJoinedEvent, + ) from getstream.video.rtc.pb.stream.video.sfu.event import events_pb2 from getstream.video.rtc.pb.stream.video.sfu.models import models_pb2 - + manager = EventManager() - + # Register generated protobuf event classes manager.register(AudioLevelEvent) manager.register(ParticipantJoinedEvent) - + assert AudioLevelEvent.type in manager._events assert ParticipantJoinedEvent.type in manager._events - + # Test 1: Send wrapped protobuf event with BaseEvent fields - proto_audio = events_pb2.AudioLevel(user_id='user123', level=0.85, is_speaking=True) - wrapped_event = AudioLevelEvent.from_proto(proto_audio, session_id='session123') - + proto_audio = events_pb2.AudioLevel(user_id="user123", level=0.85, is_speaking=True) + wrapped_event = AudioLevelEvent.from_proto(proto_audio, session_id="session123") + received_audio_events = [] - + @manager.subscribe async def handle_audio(event: AudioLevelEvent): received_audio_events.append(event) - + manager.send(wrapped_event) await manager.wait() - + assert len(received_audio_events) == 1 - assert received_audio_events[0].user_id == 'user123' - assert received_audio_events[0].session_id == 'session123' + assert received_audio_events[0].user_id == "user123" + assert received_audio_events[0].session_id == "session123" assert received_audio_events[0].is_speaking is True assert abs(received_audio_events[0].level - 0.85) < 0.01 - assert hasattr(received_audio_events[0], 'event_id') - assert hasattr(received_audio_events[0], 'timestamp') - + assert hasattr(received_audio_events[0], "event_id") + assert hasattr(received_audio_events[0], "timestamp") + # Test 2: Send raw protobuf message (auto-wrapped) - proto_raw = events_pb2.AudioLevel(user_id='user456', 
level=0.95, is_speaking=False) - + proto_raw = events_pb2.AudioLevel(user_id="user456", level=0.95, is_speaking=False) + received_audio_events.clear() manager.send(proto_raw) await manager.wait() - + assert len(received_audio_events) == 1 - assert received_audio_events[0].user_id == 'user456' + assert received_audio_events[0].user_id == "user456" assert abs(received_audio_events[0].level - 0.95) < 0.01 assert received_audio_events[0].is_speaking is False - assert hasattr(received_audio_events[0], 'event_id') - + assert hasattr(received_audio_events[0], "event_id") + # Test 3: Create event without protobuf payload (all fields optional) empty_event = AudioLevelEvent() assert empty_event.payload is None assert empty_event.user_id is None assert empty_event.event_id is not None - + # Test 4: Multiple protobuf event types received_participant_events = [] - + @manager.subscribe async def handle_participant(event: ParticipantJoinedEvent): received_participant_events.append(event) - + participant = models_pb2.Participant(user_id="user789", session_id="sess456") proto_participant = events_pb2.ParticipantJoined( - call_cid="call123", - participant=participant + call_cid="call123", participant=participant ) - + manager.send(proto_participant) await manager.wait() - + assert len(received_participant_events) == 1 assert received_participant_events[0].call_cid == "call123" assert received_participant_events[0].participant is not None - assert hasattr(received_participant_events[0], 'event_id') + assert hasattr(received_participant_events[0], "event_id") diff --git a/tests/test_function_calling.py b/tests/test_function_calling.py index 54dd3912..36c2f8cb 100644 --- a/tests/test_function_calling.py +++ b/tests/test_function_calling.py @@ -14,112 +14,112 @@ class TestFunctionRegistry: """Test the FunctionRegistry class.""" - + def test_register_function(self): """Test registering a function.""" registry = FunctionRegistry() - + @registry.register(description="Test function") def test_func(x: int, y: int = 5) -> int: """Test function with default parameter.""" return x + y - + assert "test_func" in registry._functions assert registry._functions["test_func"].description == "Test function" assert len(registry._functions["test_func"].parameters) == 2 - + def test_call_function(self): """Test calling a registered function.""" registry = FunctionRegistry() - + @registry.register(description="Add two numbers") def add_numbers(a: int, b: int) -> int: """Add two numbers.""" return a + b - + result = registry.call_function("add_numbers", {"a": 5, "b": 3}) assert result == 8 - + def test_call_function_with_defaults(self): """Test calling a function with default parameters.""" registry = FunctionRegistry() - + @registry.register(description="Test function with defaults") def test_func(x: int, y: int = 10) -> int: """Test function with default parameter.""" return x + y - + # Test with both parameters result = registry.call_function("test_func", {"x": 5, "y": 3}) assert result == 8 - + # Test with default parameter result = registry.call_function("test_func", {"x": 5}) assert result == 15 - + def test_call_nonexistent_function(self): """Test calling a non-existent function raises error.""" registry = FunctionRegistry() - + with pytest.raises(KeyError): registry.call_function("nonexistent", {}) - + def test_call_function_missing_required_param(self): """Test calling a function with missing required parameter raises error.""" registry = FunctionRegistry() - + @registry.register(description="Test function") def 
test_func(x: int, y: int) -> int: """Test function.""" return x + y - + with pytest.raises(TypeError): registry.call_function("test_func", {"x": 5}) - + def test_get_tool_schemas(self): """Test getting tool schemas.""" registry = FunctionRegistry() - + @registry.register(description="Test function") def test_func(x: int, y: int = 5) -> int: """Test function.""" return x + y - + schemas = registry.get_tool_schemas() assert len(schemas) == 1 assert schemas[0]["name"] == "test_func" assert schemas[0]["description"] == "Test function" assert "parameters_schema" in schemas[0] - + def test_get_callable(self): """Test getting callable function.""" registry = FunctionRegistry() - + @registry.register(description="Test function") def test_func(x: int) -> int: """Test function.""" return x * 2 - + callable_func = registry.get_callable("test_func") assert callable_func(5) == 10 - + with pytest.raises(KeyError): registry.get_callable("nonexistent") class TestGlobalRegistry: """Test the global function registry.""" - + def test_global_registry(self): """Test that the global registry works.""" # Clear any existing functions function_registry._functions.clear() - + @function_registry.register(description="Global test function") def global_test_func(x: int) -> int: """Global test function.""" return x * 3 - + assert "global_test_func" in function_registry._functions result = function_registry.call_function("global_test_func", {"x": 4}) assert result == 12 @@ -127,34 +127,34 @@ def global_test_func(x: int) -> int: class TestLLMFunctionCalling: """Test LLM function calling functionality.""" - + @pytest.mark.asyncio async def test_llm_function_registration(self): """Test that LLM can register functions.""" llm = LLM() - + @llm.register_function(description="Test function") def test_func(x: int) -> int: """Test function.""" return x * 2 - + functions = llm.get_available_functions() assert len(functions) == 1 assert functions[0]["name"] == "test_func" - + @pytest.mark.asyncio async def test_llm_get_available_functions(self): """Test getting available functions from LLM.""" llm = LLM() - + @llm.register_function(description="Function 1") def func1(x: int) -> int: return x + 1 - + @llm.register_function(description="Function 2") def func2(x: int) -> int: return x * 2 - + functions = llm.get_available_functions() assert len(functions) == 2 function_names = [f["name"] for f in functions] @@ -164,60 +164,68 @@ def func2(x: int) -> int: class TestOpenAIFunctionCalling: """Test OpenAI function calling functionality.""" - + @pytest.mark.asyncio - @patch('vision_agents.plugins.openai.openai_llm.AsyncOpenAI') + @patch("vision_agents.plugins.openai.openai_llm.AsyncOpenAI") async def test_openai_function_calling_response(self, mock_openai): """Test OpenAI function calling response.""" # Mock the OpenAI client and response mock_client = Mock() mock_openai.return_value = mock_client - + # Mock the responses.create call mock_response = Mock() mock_response.output = [ - Mock(type="function_call", call_id="call_123", arguments='{"location": "New York"}') + Mock( + type="function_call", + call_id="call_123", + arguments='{"location": "New York"}', + ) ] mock_client.responses.create.return_value = mock_response - + llm = OpenAILLM(api_key="test-key", model="gpt-4") - + # Register a test function @llm.register_function(description="Get weather for a location") def get_weather(location: str) -> str: """Get weather information.""" return f"Weather in {location}: Sunny, 72°F" - + # Test that function is registered functions = 
llm.get_available_functions() assert len(functions) == 1 assert functions[0]["name"] == "get_weather" - + # Test function calling result = llm.call_function("get_weather", {"location": "New York"}) assert result == "Weather in New York: Sunny, 72°F" - - @patch('vision_agents.plugins.openai.openai_llm.AsyncOpenAI') + + @patch("vision_agents.plugins.openai.openai_llm.AsyncOpenAI") async def test_openai_conversational_response(self, mock_openai): """Test OpenAI conversational response generation.""" mock_client = Mock() mock_openai.return_value = mock_client - + # Mock the responses.create call mock_response = Mock() mock_response.output = [ - Mock(type="function_call", call_id="call_123", arguments='{"location": "New York"}') + Mock( + type="function_call", + call_id="call_123", + arguments='{"location": "New York"}', + ) ] mock_client.responses.create.return_value = mock_response - + llm = OpenAILLM(api_key="test-key", model="gpt-4") - + # Register a test function @llm.register_function(description="Get weather for a location") def get_weather(location: str) -> str: """Get weather information.""" return f"Weather in {location}: Sunny, 72°F" - + # Test that function is registered functions = llm.get_available_functions() assert len(functions) == 1 @@ -226,60 +234,70 @@ def get_weather(location: str) -> str: class TestClaudeFunctionCalling: """Test Claude function calling functionality.""" - + @pytest.mark.asyncio - @patch('vision_agents.plugins.anthropic.anthropic_llm.AsyncAnthropic') + @patch("vision_agents.plugins.anthropic.anthropic_llm.AsyncAnthropic") async def test_claude_function_calling_response(self, mock_anthropic): """Test Claude function calling response.""" # Mock the Anthropic client and response mock_client = Mock() mock_anthropic.return_value = mock_client - + # Mock the messages.create call mock_response = Mock() mock_response.content = [ - Mock(type="tool_use", id="tool_123", name="get_weather", input={"location": "New York"}) + Mock( + type="tool_use", + id="tool_123", + name="get_weather", + input={"location": "New York"}, + ) ] mock_client.messages.create.return_value = mock_response - + llm = ClaudeLLM(api_key="test-key", model="claude-3-5-sonnet-20241022") - + # Register a test function @llm.register_function(description="Get weather for a location") def get_weather(location: str) -> str: """Get weather information.""" return f"Weather in {location}: Sunny, 72°F" - + # Test that function is registered functions = llm.get_available_functions() assert len(functions) == 1 assert functions[0]["name"] == "get_weather" - + # Test function calling result = llm.call_function("get_weather", {"location": "New York"}) assert result == "Weather in New York: Sunny, 72°F" - - @patch('vision_agents.plugins.anthropic.anthropic_llm.AsyncAnthropic') + + @patch("vision_agents.plugins.anthropic.anthropic_llm.AsyncAnthropic") async def test_claude_conversational_response(self, mock_anthropic): """Test Claude conversational response generation.""" mock_client = Mock() mock_anthropic.return_value = mock_client - + # Mock the messages.create call mock_response = Mock() mock_response.content = [ - Mock(type="tool_use", id="tool_123", name="get_weather", input={"location": "New York"}) + Mock( + type="tool_use", + id="tool_123", + name="get_weather", + input={"location": "New York"}, + ) ] mock_client.messages.create.return_value = mock_response - + llm = ClaudeLLM(api_key="test-key", model="claude-3-5-sonnet-20241022") - + # Register a test function @llm.register_function(description="Get 
weather for a location") def get_weather(location: str) -> str: """Get weather information.""" return f"Weather in {location}: Sunny, 72°F" - + # Test that function is registered functions = llm.get_available_functions() assert len(functions) == 1 @@ -288,67 +306,85 @@ def get_weather(location: str) -> str: class TestGeminiFunctionCalling: """Test Gemini function calling functionality.""" - + @pytest.mark.asyncio - @patch('vision_agents.plugins.gemini.gemini_llm.genai') + @patch("vision_agents.plugins.gemini.gemini_llm.genai") async def test_gemini_function_calling_response(self, mock_genai): """Test Gemini function calling response.""" # Mock the Gemini client and response mock_client = Mock() mock_genai.configure.return_value = None mock_genai.Chat.return_value = mock_client - + # Mock the send_message_stream call mock_response = Mock() mock_response.candidates = [ - Mock(content=Mock(parts=[ - Mock(type="function_call", function_call=Mock(name="get_weather", args={"location": "New York"})) - ])) + Mock( + content=Mock( + parts=[ + Mock( + type="function_call", + function_call=Mock( + name="get_weather", args={"location": "New York"} + ), + ) + ] + ) + ) ] mock_client.send_message_stream.return_value = [mock_response] - + llm = GeminiLLM(model="gemini-2.0-flash") - + # Register a test function @llm.register_function(description="Get weather for a location") def get_weather(location: str) -> str: """Get weather information.""" return f"Weather in {location}: Sunny, 72°F" - + # Test that function is registered functions = llm.get_available_functions() assert len(functions) == 1 assert functions[0]["name"] == "get_weather" - + # Test function calling result = llm.call_function("get_weather", {"location": "New York"}) assert result == "Weather in New York: Sunny, 72°F" - + @pytest.mark.asyncio - @patch('vision_agents.plugins.gemini.gemini_llm.genai') + @patch("vision_agents.plugins.gemini.gemini_llm.genai") async def test_gemini_conversational_response(self, mock_genai): """Test Gemini conversational response generation.""" mock_client = Mock() mock_genai.configure.return_value = None mock_genai.Chat.return_value = mock_client - + # Mock the send_message_stream call mock_response = Mock() mock_response.candidates = [ - Mock(content=Mock(parts=[ - Mock(type="function_call", function_call=Mock(name="get_weather", args={"location": "New York"})) - ])) + Mock( + content=Mock( + parts=[ + Mock( + type="function_call", + function_call=Mock( + name="get_weather", args={"location": "New York"} + ), + ) + ] + ) + ) ] mock_client.send_message_stream.return_value = [mock_response] - + llm = GeminiLLM(model="gemini-2.0-flash") - + # Register a test function @llm.register_function(description="Get weather for a location") def get_weather(location: str) -> str: """Get weather information.""" return f"Weather in {location}: Sunny, 72°F" - + # Test that function is registered functions = llm.get_available_functions() assert len(functions) == 1 @@ -357,85 +393,82 @@ def get_weather(location: str) -> str: class TestFunctionCallingIntegration: """Test function calling integration scenarios.""" - + @pytest.mark.asyncio async def test_tool_call_processing(self): """Test processing tool calls with multiple functions.""" llm = LLM() - + @llm.register_function(description="Get weather") def get_weather(location: str) -> str: return f"Weather in {location}: Sunny" - + @llm.register_function(description="Calculate sum") def calculate_sum(a: int, b: int) -> int: return a + b - + # Test multiple function registrations 
functions = llm.get_available_functions() assert len(functions) == 2 - + # Test calling both functions weather_result = llm.call_function("get_weather", {"location": "NYC"}) sum_result = llm.call_function("calculate_sum", {"a": 5, "b": 3}) - + assert weather_result == "Weather in NYC: Sunny" assert sum_result == 8 - + @pytest.mark.asyncio async def test_error_handling_in_function_calls(self): """Test error handling in function calls.""" llm = LLM() - + @llm.register_function(description="Test function that raises error") def error_function(x: int) -> int: if x < 0: raise ValueError("Negative numbers not allowed") return x * 2 - + # Test normal case result = llm.call_function("error_function", {"x": 5}) assert result == 10 - + # Test error case with pytest.raises(ValueError): llm.call_function("error_function", {"x": -5}) - + @pytest.mark.asyncio async def test_function_schema_generation(self): """Test that function schemas are generated correctly.""" llm = LLM() - + @llm.register_function(description="Complex function") def complex_function( - name: str, - age: int, - is_active: bool = True, - tags: list = None + name: str, age: int, is_active: bool = True, tags: list = None ) -> dict: """Complex function with various parameter types.""" return { "name": name, "age": age, "is_active": is_active, - "tags": tags or [] + "tags": tags or [], } - + schemas = llm.get_available_functions() assert len(schemas) == 1 - + schema = schemas[0] assert schema["name"] == "complex_function" assert schema["description"] == "Complex function" assert "parameters_schema" in schema - + # Check parameter types params = schema["parameters_schema"]["properties"] assert "name" in params assert "age" in params assert "is_active" in params assert "tags" in params - + # Check required parameters required = schema["parameters_schema"]["required"] assert "name" in required @@ -446,87 +479,93 @@ def complex_function( class TestConcurrentToolExecution: """Test concurrent tool execution functionality.""" - + @pytest.mark.asyncio async def test_dedup_and_execute(self): """Test the _dedup_and_execute method.""" llm = LLM() - + @llm.register_function(description="Test function") def test_func(x: int) -> int: return x * 2 - + # Test with duplicate tool calls tool_calls = [ {"id": "call1", "name": "test_func", "arguments_json": {"x": 5}}, - {"id": "call2", "name": "test_func", "arguments_json": {"x": 5}}, # Duplicate + { "id": "call2", "name": "test_func", "arguments_json": {"x": 5}, }, # Duplicate {"id": "call3", "name": "test_func", "arguments_json": {"x": 3}}, ] - + # Dedup is keyed on (id, name, arguments_json); every call here has a distinct id, so nothing is actually deduplicated and all three execute triples, seen = await llm._dedup_and_execute(tool_calls) assert len(triples) == 3 # All calls have different IDs, so all are executed assert len(seen) == 3 # 3 unique keys in seen set - + # Check results results = [result for _, result, _ in triples] assert 10 in results # 5 * 2 (appears twice) - assert 6 in results # 3 * 2 - + assert 6 in results # 3 * 2 + @pytest.mark.asyncio async def test_tool_lifecycle_events(self): """Test that tool lifecycle events are emitted.""" from vision_agents.core.llm.events import ToolStartEvent, ToolEndEvent - + llm = LLM() - + @llm.register_function(description="Test function") def test_func(x: int) -> int: return x * 2 - + # Track emitted events start_events = [] end_events = [] - + @llm.events.subscribe async
def track_start_event(event: ToolStartEvent): start_events.append(event) - + @llm.events.subscribe async def track_end_event(event: ToolEndEvent): end_events.append(event) - + # Execute a tool call - await llm._run_one_tool({"id": "call1", "name": "test_func", "arguments_json": {"x": 5}}, 30.0) + await llm._run_one_tool( + {"id": "call1", "name": "test_func", "arguments_json": {"x": 5}}, 30.0 + ) # Wait for events await llm.events.wait(timeout=1.0) - + # Check that events were emitted assert len(start_events) == 1 assert len(end_events) == 1 assert start_events[0].tool_name == "test_func" assert end_events[0].tool_name == "test_func" assert end_events[0].success is True - + @pytest.mark.asyncio async def test_output_sanitization(self): """Test output sanitization for large responses.""" llm = LLM() - + # Test normal output normal_output = "Hello world" sanitized = llm._sanitize_tool_output(normal_output) assert sanitized == "Hello world" - + # Test large output large_output = "x" * 70000 # Larger than default 60k limit sanitized = llm._sanitize_tool_output(large_output) assert len(sanitized) == 60001 # 60k + "…" assert sanitized.endswith("…") - + # Test non-string output dict_output = {"key": "value"} sanitized = llm._sanitize_tool_output(dict_output) - assert sanitized == '{"key": "value"}' \ No newline at end of file + assert sanitized == '{"key": "value"}' diff --git a/tests/test_mcp_integration.py b/tests/test_mcp_integration.py index 5b44332b..1f7a0d90 100644 --- a/tests/test_mcp_integration.py +++ b/tests/test_mcp_integration.py @@ -12,6 +12,7 @@ load_dotenv() + def get_mcp_server(): """Get configured MCP server based on environment variables.""" local_cmd = os.getenv("MCP_LOCAL_CMD") @@ -135,6 +136,7 @@ async def test_mcp_multiple_tool_calls(): tools_after = await server.list_tools() assert len(tools_after) == len(tools), "Tool count changed unexpectedly" + @pytest.mark.integration @pytest.mark.asyncio async def test_mcp_resources(): @@ -171,7 +173,7 @@ async def test_mcp_concurrent_calls(): url="https://api.githubcopilot.com/mcp/", headers={"Authorization": f"Bearer {github_pat}"}, timeout=30.0, - session_timeout=60.0 + session_timeout=60.0, ) async with server: @@ -195,7 +197,7 @@ async def call_tool(tool, args): else: # Generic arguments for unknown GitHub tools args = {"query": f"concurrent_test_{i}"} - + tasks.append(call_tool(tool, args)) results = await asyncio.gather(*tasks, return_exceptions=True) @@ -227,9 +229,9 @@ async def test_mcp_tool_schema_validation(): assert tool.name, f"Tool name is empty: {tool}" # Verify inputSchema is a dict - assert isinstance( - tool.inputSchema, dict - ), f"Tool inputSchema should be dict: {tool.inputSchema}" + assert isinstance(tool.inputSchema, dict), ( + f"Tool inputSchema should be dict: {tool.inputSchema}" + ) @pytest.mark.integration @@ -263,9 +265,9 @@ async def test_mcp_github_integration(): ] found_github_tools = [name for name in github_tools if name in tool_names] - assert ( - len(found_github_tools) > 0 - ), f"No expected GitHub tools found. Available: {tool_names}" + assert len(found_github_tools) > 0, ( + f"No expected GitHub tools found. Available: {tool_names}" + ) # Test a simple GitHub tool call if "search_repositories" in tool_names: @@ -304,13 +306,13 @@ async def test_mcp_weather_integration(): @pytest.mark.asyncio async def test_openai_llm_mcp_weather_integration(): """Test OpenAI LLM integration with MCP weather server. - + This test verifies the complete flow: 1. Agent connects to MCP weather server - 2. 
MCP tools are registered with LLM function registry + 2. MCP tools are registered with LLM function registry 3. LLM makes function calls to MCP tools 4. Tool results are processed and returned - + Requires: - OPENAI_API_KEY environment variable - MCP_LOCAL_CMD pointing to weather server @@ -319,19 +321,25 @@ async def test_openai_llm_mcp_weather_integration(): # Skip if credentials not available if not os.getenv("OPENAI_API_KEY"): pytest.skip("OPENAI_API_KEY not set, skipping OpenAI MCP integration test") - if not os.getenv("MCP_LOCAL_CMD") or "transport.py" not in os.getenv("MCP_LOCAL_CMD", ""): + if not os.getenv("MCP_LOCAL_CMD") or "transport.py" not in os.getenv( + "MCP_LOCAL_CMD", "" + ): pytest.skip("MCP_LOCAL_CMD not set to weather server, skipping test") if not os.getenv("STREAM_API_KEY") or not os.getenv("STREAM_API_SECRET"): pytest.skip("STREAM_API_KEY or STREAM_API_SECRET not set, skipping test") # Setup components - llm = OpenAILLM(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY")) # Use cheaper model - weather_server = MCPServerLocal(command=os.getenv("MCP_LOCAL_CMD"), session_timeout=60.0) - + llm = OpenAILLM( + model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY") + ) # Use cheaper model + weather_server = MCPServerLocal( + command=os.getenv("MCP_LOCAL_CMD"), session_timeout=60.0 + ) + # Create real edge and agent user edge = getstream.Edge() agent_user = User(name="Weather Assistant", id="weather-agent") - + # Create agent with required processing capabilities agent = Agent( edge=edge, @@ -342,16 +350,16 @@ async def test_openai_llm_mcp_weather_integration(): tts=elevenlabs.TTS(), stt=deepgram.STT(), ) - + try: # Connect to MCP server await agent._connect_mcp_servers() - + # Verify tools are registered available_functions = agent.llm.get_available_functions() - mcp_functions = [f for f in available_functions if f['name'].startswith('mcp_')] + mcp_functions = [f for f in available_functions if f["name"].startswith("mcp_")] assert len(mcp_functions) > 0, "No MCP tools registered" - + # Test function calling response = await agent.llm.simple_response( text="What's the weather like in London?", @@ -359,6 +367,6 @@ async def test_openai_llm_mcp_weather_integration(): # Verify response was received (the core integration test) assert response is not None, "No response received from LLM" - + finally: - await agent.close() \ No newline at end of file + await agent.close() diff --git a/tests/test_openai_function_calling_integration.py b/tests/test_openai_function_calling_integration.py index 04ec20c4..7d066629 100644 --- a/tests/test_openai_function_calling_integration.py +++ b/tests/test_openai_function_calling_integration.py @@ -25,7 +25,9 @@ async def test_openai_function_calling_live_roundtrip(): # Side-effect to prove the tool actually ran calls: list[str] = [] - @llm.register_function(description="Probe tool that records invocation and returns a marker string") + @llm.register_function( + description="Probe tool that records invocation and returns a marker string" + ) def probe_tool(ping: str) -> str: calls.append(ping) return f"probe_ok:{ping}" diff --git a/tests/test_queue_and_video_forwarder.py b/tests/test_queue_and_video_forwarder.py index ea237863..822e2380 100644 --- a/tests/test_queue_and_video_forwarder.py +++ b/tests/test_queue_and_video_forwarder.py @@ -4,84 +4,85 @@ from vision_agents.core.utils.queue import LatestNQueue from vision_agents.core.utils.video_forwarder import VideoForwarder + class TestLatestNQueue: """Test suite for LatestNQueue""" - + 
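The function-calling tests above all revolve around the same decorator-based surface: `@llm.register_function(description=...)` stores a callable plus a schema, `get_available_functions()` lists entries with `name`, `description`, and `parameters_schema`, and `call_function()` dispatches by name. The `LLM` class itself is not part of this diff, so the following is only a minimal sketch of a registry with that contract, using `inspect` for parameter introspection; everything beyond the names the tests exercise is an assumption:

```python
import inspect
from typing import Any, Callable, Dict, List


class FunctionRegistry:
    """Hypothetical registry mirroring the surface the tests exercise."""

    def __init__(self) -> None:
        self._functions: Dict[str, Dict[str, Any]] = {}

    def register_function(self, description: str) -> Callable:
        def decorator(fn: Callable) -> Callable:
            params = inspect.signature(fn).parameters
            schema = {
                "type": "object",
                "properties": {name: {} for name in params},
                # Parameters without defaults are required.
                "required": [
                    name
                    for name, p in params.items()
                    if p.default is inspect.Parameter.empty
                ],
            }
            self._functions[fn.__name__] = {
                "name": fn.__name__,
                "description": description,
                "parameters_schema": schema,
                "fn": fn,
            }
            return fn

        return decorator

    def get_available_functions(self) -> List[Dict[str, Any]]:
        # Expose schemas without the callables, as the tests expect.
        return [
            {k: v for k, v in entry.items() if k != "fn"}
            for entry in self._functions.values()
        ]

    def call_function(self, name: str, arguments: Dict[str, Any]) -> Any:
        return self._functions[name]["fn"](**arguments)
```

Under this sketch the assertions in `test_function_schema_generation` hold: `complex_function`'s schema lists `name` and `age` as required, while `is_active` and `tags` (which have defaults) are optional.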
diff --git a/tests/test_queue_and_video_forwarder.py b/tests/test_queue_and_video_forwarder.py
index ea237863..822e2380 100644
--- a/tests/test_queue_and_video_forwarder.py
+++ b/tests/test_queue_and_video_forwarder.py
@@ -4,84 +4,85 @@
 from vision_agents.core.utils.queue import LatestNQueue
 from vision_agents.core.utils.video_forwarder import VideoForwarder
 
+
 class TestLatestNQueue:
     """Test suite for LatestNQueue"""
-
+
     @pytest.mark.asyncio
     async def test_basic_put_get(self):
         """Test basic put and get operations"""
         queue = LatestNQueue[int](maxlen=3)
-
+
         await queue.put_latest(1)
         await queue.put_latest(2)
         await queue.put_latest(3)
-
+
         assert await queue.get() == 1
         assert await queue.get() == 2
         assert await queue.get() == 3
-
+
     @pytest.mark.asyncio
     async def test_put_latest_discards_oldest(self):
         """Test that put_latest discards oldest items when full"""
         queue = LatestNQueue[int](maxlen=2)
-
+
         await queue.put_latest(1)
         await queue.put_latest(2)
         await queue.put_latest(3)  # Should discard 1
-
+
         assert await queue.get() == 2
         assert await queue.get() == 3
-
+
         # Queue should be empty now
         with pytest.raises(asyncio.QueueEmpty):
             queue.get_nowait()
-
+
     @pytest.mark.asyncio
     async def test_put_latest_nowait(self):
         """Test synchronous put_latest_nowait"""
         queue = LatestNQueue[int](maxlen=2)
-
+
         queue.put_latest_nowait(1)
         queue.put_latest_nowait(2)
         queue.put_latest_nowait(3)  # Should discard 1
-
+
         assert queue.get_nowait() == 2
         assert queue.get_nowait() == 3
-
+
     @pytest.mark.asyncio
     async def test_put_latest_nowait_discards_oldest(self):
         """Test that put_latest_nowait discards oldest when full"""
         queue = LatestNQueue[int](maxlen=3)
-
+
         # Fill queue
         queue.put_latest_nowait(1)
         queue.put_latest_nowait(2)
         queue.put_latest_nowait(3)
-
+
         # Add more items, should discard oldest
         queue.put_latest_nowait(4)  # Discards 1
         queue.put_latest_nowait(5)  # Discards 2
-
+
         # Should have 3, 4, 5
         items = []
         while not queue.empty():
             items.append(queue.get_nowait())
-
+
         assert items == [3, 4, 5]
-
+
     @pytest.mark.asyncio
     async def test_queue_size_limits(self):
         """Test that queue respects size limits"""
         queue = LatestNQueue[int](maxlen=1)
-
+
         await queue.put_latest(1)
         assert queue.full()
-
+
         # Adding another should discard the first
         await queue.put_latest(2)
         assert queue.full()
         assert await queue.get() == 2
-
+
     @pytest.mark.asyncio
     async def test_generic_type_support(self):
         """Test that queue works with different types"""
@@ -90,20 +91,20 @@ async def test_generic_type_support(self):
         await str_queue.put_latest("a")
         await str_queue.put_latest("b")
         await str_queue.put_latest("c")  # Should discard "a"
-
+
         assert await str_queue.get() == "b"
         assert await str_queue.get() == "c"
-
+
         # Test with custom objects
         class TestObj:
             def __init__(self, value):
                 self.value = value
-
+
         obj_queue = LatestNQueue[TestObj](maxlen=2)
         await obj_queue.put_latest(TestObj(1))
         await obj_queue.put_latest(TestObj(2))
         await obj_queue.put_latest(TestObj(3))  # Should discard first
-
+
         obj2 = await obj_queue.get()
         obj3 = await obj_queue.get()
         assert obj2.value == 2
@@ -112,247 +113,250 @@ def __init__(self, value):
 
 class TestVideoForwarder:
     """Test suite for VideoForwarder using real video data"""
-
+
     @pytest.mark.asyncio
     async def test_video_forwarder_initialization(self, bunny_video_track):
         """Test VideoForwarder initialization"""
         forwarder = VideoForwarder(bunny_video_track, max_buffer=5, fps=30.0)
-
+
         assert forwarder.input_track == bunny_video_track
         assert forwarder.queue.maxsize == 5
         assert forwarder.fps == 30.0
         assert len(forwarder._tasks) == 0
         assert not forwarder._stopped.is_set()
-
+
     @pytest.mark.asyncio
     async def test_start_stop_lifecycle(self, bunny_video_track):
         """Test start and stop lifecycle"""
         forwarder = VideoForwarder(bunny_video_track, max_buffer=3)
-
+
         # Start forwarder
         await forwarder.start()
         assert len(forwarder._tasks) == 1
         assert not forwarder._stopped.is_set()
-
+
         # Let it run briefly
         await asyncio.sleep(0.01)
-
+
         # Stop forwarder
         await forwarder.stop()
         assert len(forwarder._tasks) == 0
         assert forwarder._stopped.is_set()
-
+
     @pytest.mark.asyncio
     async def test_next_frame_pull_model(self, bunny_video_track):
         """Test next_frame pull model"""
         forwarder = VideoForwarder(bunny_video_track, max_buffer=3)
-
+
         await forwarder.start()
-
+
         try:
             # Get first frame
             frame = await forwarder.next_frame(timeout=1.0)
-            assert hasattr(frame, 'to_ndarray')  # Real video frame
-
+            assert hasattr(frame, "to_ndarray")  # Real video frame
+
             # Get a few more frames
             for _ in range(3):
                 frame = await forwarder.next_frame(timeout=1.0)
-                assert hasattr(frame, 'to_ndarray')  # Real video frame
-
+                assert hasattr(frame, "to_ndarray")  # Real video frame
+
         finally:
             await forwarder.stop()
-
+
     @pytest.mark.asyncio
     async def test_next_frame_coalesces_to_newest(self, bunny_video_track):
         """Test that next_frame coalesces backlog to newest frame"""
         forwarder = VideoForwarder(bunny_video_track, max_buffer=5)
-
+
         await forwarder.start()
-
+
         try:
             # Let multiple frames accumulate
             await asyncio.sleep(0.05)
-
+
             # Get frame - should be the newest available
             frame = await forwarder.next_frame(timeout=1.0)
-            assert hasattr(frame, 'to_ndarray')  # Real video frame
-
+            assert hasattr(frame, "to_ndarray")  # Real video frame
+
         finally:
             await forwarder.stop()
-
+
     @pytest.mark.asyncio
     async def test_callback_push_model(self, bunny_video_track):
         """Test callback-based push model"""
         forwarder = VideoForwarder(bunny_video_track, max_buffer=3, fps=10.0)
-
+
         received_frames = []
-
+
         def on_frame(frame):
             received_frames.append(frame)
-
+
         await forwarder.start()
-
+
         try:
             # Start callback consumer
             await forwarder.start_event_consumer(on_frame)
-
+
             # Let it run and collect frames
             await asyncio.sleep(0.1)
-
+
             # Should have received some frames
             assert len(received_frames) > 0
             for frame in received_frames:
-                assert hasattr(frame, 'to_ndarray')  # Real video frame
-
+                assert hasattr(frame, "to_ndarray")  # Real video frame
+
         finally:
             await forwarder.stop()
-
+
     @pytest.mark.asyncio
     async def test_async_callback_push_model(self, bunny_video_track):
         """Test async callback-based push model"""
         forwarder = VideoForwarder(bunny_video_track, max_buffer=3, fps=10.0)
-
+
         received_frames = []
-
+
         async def async_on_frame(frame):
             received_frames.append(frame)
             await asyncio.sleep(0.001)  # Simulate async work
-
+
         await forwarder.start()
-
+
         try:
             # Start async callback consumer
             await forwarder.start_event_consumer(async_on_frame)
-
+
             # Let it run and collect frames
             await asyncio.sleep(0.1)
-
+
             # Should have received some frames
             assert len(received_frames) > 0
             for frame in received_frames:
-                assert hasattr(frame, 'to_ndarray')  # Real video frame
-
+                assert hasattr(frame, "to_ndarray")  # Real video frame
+
         finally:
             await forwarder.stop()
-
+
     @pytest.mark.asyncio
     async def test_fps_throttling(self, bunny_video_track):
         """Test FPS throttling in callback mode"""
         forwarder = VideoForwarder(bunny_video_track, max_buffer=3, fps=5.0)  # 5 FPS
-
+
         received_frames = []
         timestamps = []
-
+
         def on_frame(frame):
             received_frames.append(frame)
             timestamps.append(asyncio.get_event_loop().time())
-
+
         await forwarder.start()
-
+
         try:
             await forwarder.start_event_consumer(on_frame)
-
+
             # Let it run for a bit
             await asyncio.sleep(0.5)
-
+
             # Should have received frames
             assert len(received_frames) > 0
-
+
             # Check that frames are throttled (roughly 5 FPS)
             if len(timestamps) > 1:
-                intervals = [timestamps[i+1] - timestamps[i] for i in range(len(timestamps)-1)]
+                intervals = [
+                    timestamps[i + 1] - timestamps[i]
+                    for i in range(len(timestamps) - 1)
+                ]
                 avg_interval = sum(intervals) / len(intervals)
                 # Should be roughly 1/5 = 0.2 seconds between frames
                 assert avg_interval >= 0.15  # Allow some tolerance
-
+
         finally:
             await forwarder.stop()
-
+
     @pytest.mark.asyncio
     async def test_producer_handles_track_errors(self, bunny_video_track):
         """Test that producer handles track errors gracefully"""
         # Mock track to raise exception after a few frames
         call_count = 0
         original_recv = bunny_video_track.recv
-
+
         async def failing_recv():
             nonlocal call_count
             call_count += 1
             if call_count > 3:
                 raise Exception("Track error")
             return await original_recv()
-
+
         bunny_video_track.recv = failing_recv
-
+
         forwarder = VideoForwarder(bunny_video_track, max_buffer=3)
-
+
         await forwarder.start()
-
+
         try:
             # Should still be able to get some frames before error
             frame = await forwarder.next_frame(timeout=1.0)
-            assert hasattr(frame, 'to_ndarray')  # Real video frame
-
+            assert hasattr(frame, "to_ndarray")  # Real video frame
+
             # Let it run a bit more to trigger error
             await asyncio.sleep(0.1)
-
+
         finally:
             await forwarder.stop()
-
+
     @pytest.mark.asyncio
     async def test_stop_drains_queue(self, bunny_video_track):
         """Test that stop drains the queue"""
         forwarder = VideoForwarder(bunny_video_track, max_buffer=5)
-
+
         await forwarder.start()
-
+
         try:
             # Let some frames accumulate
             await asyncio.sleep(0.05)
-
+
             # Stop should drain queue
             await forwarder.stop()
-
+
             # Queue should be empty after stop
             assert forwarder.queue.empty()
-
+
         except Exception:
             await forwarder.stop()
-
+
     @pytest.mark.asyncio
     async def test_no_fps_limit(self, bunny_video_track):
         """Test behavior when fps is None (no limit)"""
         forwarder = VideoForwarder(bunny_video_track, max_buffer=3, fps=None)
-
+
         received_frames = []
         timestamps = []
-
+
         def on_frame(frame):
             received_frames.append(frame)
             timestamps.append(asyncio.get_event_loop().time())
-
+
         await forwarder.start()
-
+
         try:
             await forwarder.start_event_consumer(on_frame)
-
+
             # Let it run briefly
             await asyncio.sleep(0.1)
-
+
             # Should have received frames
             assert len(received_frames) > 0
-
+
             # With no FPS limit, frames should come as fast as possible
             # (limited by track delay and processing time)
-
+
         finally:
             await forwarder.stop()
-
+
     async def test_bunny_video_track_frame_count(self, bunny_video_track):
         """Test how many frames are actually available from bunny_video_track"""
         frame_count = 0
         frames = []
-
+
         try:
             while True:
                 frame = await bunny_video_track.recv()
@@ -365,70 +369,74 @@
             print(f"Error after {frame_count} frames: {e}")
 
         assert frame_count == 45
-
+
     async def test_frame_count_at_10fps(self, bunny_video_track):
         """Test that VideoForwarder generates ~30 frames at 10fps from 3-second video"""
         forwarder = VideoForwarder(bunny_video_track, max_buffer=10, fps=10.0)
-
+
         received_frames = []
         timestamps = []
-
+
         def on_frame(frame):
             received_frames.append(frame)
             timestamps.append(asyncio.get_event_loop().time())
-
+
         await forwarder.start()
-
+
         try:
             await forwarder.start_event_consumer(on_frame)
-
+
             # Let it run for the full 3-second video duration
             await asyncio.sleep(10)  # Slightly longer to ensure we get all frames
-
+
             # Should have received approximately 30 frames (3 seconds * 10 fps)
             # Allow some tolerance for timing variations
-            assert 25 <= len(received_frames) <= 35, f"Expected ~30 frames, got {len(received_frames)}"
-
+            assert 25 <= len(received_frames) <= 35, (
+                f"Expected ~30 frames, got {len(received_frames)}"
+            )
+
         finally:
             await forwarder.stop()
 
 
 class TestVideoForwarderIntegration:
     """Integration tests for VideoForwarder with real video data"""
-
+
     @pytest.mark.asyncio
     async def test_video_forwarder_with_callback_processing(self, bunny_video_track):
         """Test VideoForwarder with callback-based processing"""
         processed_frames = []
-
+
         async def process_frame(frame):
             # Simulate frame processing
             processed_data = frame.to_ndarray()
-            processed_frames.append({
-                'data_shape': processed_data.shape,
-                'has_to_ndarray': hasattr(frame, 'to_ndarray')
-            })
-
+            processed_frames.append(
+                {
+                    "data_shape": processed_data.shape,
+                    "has_to_ndarray": hasattr(frame, "to_ndarray"),
+                }
+            )
+
         forwarder = VideoForwarder(bunny_video_track, max_buffer=4, fps=20.0)
-
+
         await forwarder.start()
-
+
         try:
             await forwarder.start_event_consumer(process_frame)
-
+
             # Let processing run
             await asyncio.sleep(0.15)
-
+
             # Verify frames were processed
             assert len(processed_frames) > 0
-
+
             # Verify processing data
             for processed in processed_frames:
-                assert 'data_shape' in processed
-                assert 'has_to_ndarray' in processed
-                assert processed['has_to_ndarray'] is True
+                assert "data_shape" in processed
+                assert "has_to_ndarray" in processed
+                assert processed["has_to_ndarray"] is True
                 # Real video frames will have varying shapes
-                assert len(processed['data_shape']) == 3  # height, width, channels
-
+                assert len(processed["data_shape"]) == 3  # height, width, channels
+
         finally:
             await forwarder.stop()
diff --git a/tests/test_realtime_base.py b/tests/test_realtime_base.py
index 990c8958..34277202 100644
--- a/tests/test_realtime_base.py
+++ b/tests/test_realtime_base.py
@@ -33,7 +33,7 @@ def partial_update_message(self, text: str, participant: Any = None) -> None:
 
     def finish_last_message(self, text: str) -> None:
         self.finish_calls.append(text)
-
+
     def add_message(self, message: dict) -> None:
         """Add a message to the conversation."""
         self.messages.append(message)
@@ -66,40 +66,42 @@ async def simple_audio_response(self, pcm):
 
     async def _close_impl(self):
         return None
-
+
     # Additional methods from LLM base class
     def set_before_response_listener(self, callback):
         """Set before response callback."""
         self.before_response_listener = callback
-
+
     def set_after_response_listener(self, callback):
         """Set after response callback."""
         self.after_response_listener = callback
-
+
     async def wait_until_ready(self, timeout: float = 5.0) -> bool:
         """Wait until ready (already ready in fake)."""
         return True
-
+
     async def interrupt_playback(self):
         """Interrupt playback (no-op for fake)."""
         pass
-
+
     def resume_playback(self):
         """Resume playback (no-op for fake)."""
         pass
 
+
 @pytest.mark.skip(reason="Conversation class has not fully been wired into Agent yet")
 @pytest.mark.asyncio
 async def test_agent_conversation_updates_with_realtime():
     """Test that Agent wires Realtime events to conversation updates."""
     from vision_agents.core.edge import EdgeTransport
     from vision_agents.core.edge.types import User, Connection
-
+
     # ===================================================================
     # Mock Connection - mimics the structure Agent expects
     # ===================================================================
     class MockConnection(Connection):
         """Mock connection with minimal structure for testing."""
+
         def __init__(self):
             super().__init__()
             # Agent.join() accesses connection._connection._coordinator_ws_client.on_wildcard()
             self._connection = SimpleNamespace(
                 _coordinator_ws_client=SimpleNamespace(
@@ -108,86 +110,81 @@ def __init__(self):
                     on_wildcard=lambda *args, **kwargs: None
                 )
             )
-
+
         async def close(self):
             pass
-
+
     # ===================================================================
     # Mock EdgeTransport - provides conversation and connection
     # ===================================================================
     class MockEdge(EdgeTransport):
         """Mock edge transport for testing Agent integration."""
+
         def __init__(self):
             super().__init__()
             self.conversation = None
             # EdgeTransport doesn't initialize events, but Agent expects it
             from vision_agents.core.events.manager import EventManager
+
             self.events = EventManager()
-
+
         async def create_user(self, user: User):
             return user
-
+
         def create_audio_track(self):
             return None
-
+
         def close(self):
             pass
-
+
         def open_demo(self, *args, **kwargs):
             pass
-
+
         async def join(self, agent, call):
             """Return a mock connection."""
             return MockConnection()
-
+
         async def publish_tracks(self, audio_track, video_track):
             pass
-
+
         async def create_conversation(self, call, user, instructions):
             """Return our fake conversation for testing."""
             return self.conversation
-
+
         def add_track_subscriber(self, track_id):
             return None
-
+
     # ===================================================================
     # Fake Conversation - tracks partial and final updates
     # ===================================================================
     fake_conv = FakeConversation()
-
+
     # ===================================================================
     # Create Agent with new API
     # ===================================================================
     rt = FakeRealtime()
     mock_edge = MockEdge()
     mock_edge.conversation = fake_conv  # Set before join
-
+
     agent_user = User(id="agent-123", name="Test Agent")
-
+
     agent = Agent(
-        edge=mock_edge,
-        llm=rt,
-        agent_user=agent_user,
-        instructions="Test instructions"
+        edge=mock_edge, llm=rt, agent_user=agent_user, instructions="Test instructions"
     )
-
+
     # ===================================================================
     # Mock Call object
     # ===================================================================
     call = SimpleNamespace(
         id="test-call-123",
-        client=SimpleNamespace(
-            stream=SimpleNamespace(
-                chat=SimpleNamespace()
-            )
-        )
+        client=SimpleNamespace(stream=SimpleNamespace(chat=SimpleNamespace())),
     )
-
+
     # ===================================================================
     # Join call (registers event handlers)
     # ===================================================================
     await agent.join(call)
-
+
     # ===================================================================
     # Trigger events through FakeRealtime
     # ===================================================================
@@ -196,22 +193,24 @@ def add_track_subscriber(self, track_id):
     # 2. RealtimeResponseEvent (partial: "Hello")
     # 3. RealtimeResponseEvent (complete: "Hello world")
     await rt.send_text("Hi")
-
+
     # Allow async event handlers to run
     await asyncio.sleep(0.05)
-
+
     # Wait for event processing
     await agent.events.wait(timeout=1.0)
-
+
     # ===================================================================
     # Assertions - verify conversation received updates
     # ===================================================================
-    assert ("Hello", None) in fake_conv.partial_calls, \
+    assert ("Hello", None) in fake_conv.partial_calls, (
         f"Expected partial update 'Hello', got: {fake_conv.partial_calls}"
-
-    assert "Hello world" in fake_conv.finish_calls, \
+    )
+
+    assert "Hello world" in fake_conv.finish_calls, (
         f"Expected finish call 'Hello world', got: {fake_conv.finish_calls}"
-
+    )
+
     # Cleanup
     await agent.close()
@@ -238,17 +237,17 @@ async def send_text(self, text: str):
     async def simple_response(self, text: str, processors=None, participant=None):
         """Aggregates streaming responses."""
         # Call before listener if set
-        if hasattr(self, 'before_response_listener') and self.before_response_listener:
+        if hasattr(self, "before_response_listener") and self.before_response_listener:
             self.before_response_listener([{"role": "user", "content": text}])
-
+
         await self.send_text(text)
         # Aggregate all response events ("Hi " + "there" + "!")
         result = LLMResponseEvent(original=None, text="Hi there!")
-
+
         # Call after listener if set
-        if hasattr(self, 'after_response_listener') and self.after_response_listener:
+        if hasattr(self, "after_response_listener") and self.after_response_listener:
             await self.after_response_listener(result)
-
+
         return result
 
     async def simple_audio_response(self, pcm):
@@ -257,12 +256,12 @@ async def simple_audio_response(self, pcm):
 
     async def _close_impl(self):
         return None
-
+
     # Additional methods from LLM base class
     def set_before_response_listener(self, callback):
         """Set before response callback."""
         self.before_response_listener = callback
-
+
     def set_after_response_listener(self, callback):
         """Set after response callback."""
         self.after_response_listener = callback
@@ -312,10 +311,10 @@ async def _on_disc(event: RealtimeDisconnectedEvent):
 
     await asyncio.sleep(0.01)
     await rt.close()
-
+
     # Wait for all events in queue to be processed
     await rt.events.wait(timeout=1.0)
-
+
     assert observed["disconnected"] is True
@@ -358,7 +357,7 @@ async def simple_audio_response(self, pcm):
 
     async def _close_impl(self):
         return None
-
+
     # Additional method for native_response test
     async def native_response(self, **kwargs):
         """Native response aggregates streaming responses."""
diff --git a/tests/test_stt_base.py b/tests/test_stt_base.py
index 83e79489..63fccda1 100644
--- a/tests/test_stt_base.py
+++ b/tests/test_stt_base.py
@@ -10,7 +10,11 @@
 from dotenv import load_dotenv
 
 from vision_agents.core.stt.stt import STT
-from vision_agents.core.stt.events import STTTranscriptEvent, STTPartialTranscriptEvent, STTErrorEvent
+from vision_agents.core.stt.events import (
+    STTTranscriptEvent,
+    STTPartialTranscriptEvent,
+    STTErrorEvent,
+)
 from getstream.video.rtc.track_util import PcmData
 from vision_agents.core.agents import Agent
 from vision_agents.core.edge.types import User, Participant
@@ -106,7 +110,7 @@ async def on_transcript(event: STTTranscriptEvent):
         metadata = {"confidence": 0.95, "processing_time_ms": 100}
 
         mock_stt._emit_transcript_event(text, user_metadata, metadata)
-
+
         # Wait for event processing
         await mock_stt.events.wait(timeout=1.0)
@@ -135,7 +139,7 @@ async def on_partial_transcript(event: STTPartialTranscriptEvent):
         metadata = {"confidence": 0.8}
 
         mock_stt._emit_partial_transcript_event(text, user_metadata, metadata)
-
+
         # Wait for event processing
         await mock_stt.events.wait(timeout=1.0)
@@ -160,7 +164,7 @@ async def on_error(event: STTErrorEvent):
         # Emit an error event
         test_error = Exception("Test error")
         mock_stt._emit_error_event(test_error, "test context")
-
+
         # Wait for event processing
         await mock_stt.events.wait(timeout=1.0)
@@ -197,7 +201,7 @@ async def on_transcript(event: STTTranscriptEvent):
         # Process audio
         user_metadata = {"user_id": "123"}
         await mock_stt.process_audio(valid_pcm_data, user_metadata)
-
+
         # Wait for event processing
         await mock_stt.events.wait(timeout=1.0)
@@ -246,7 +250,7 @@ async def on_error(event: STTErrorEvent):
 
         # Process audio (should not raise exception)
         await mock_stt_with_exception.process_audio(valid_pcm_data)
-
+
         # Wait for event processing
         await mock_stt_with_exception.events.wait(timeout=1.0)
@@ -260,6 +264,7 @@ async def on_error(event: STTErrorEvent):
 # Integration Tests
 # ============================================================================
 
+
 class TestSTTIntegration(BaseTest):
     """Integration tests for STT with real components."""
 
@@ -267,10 +272,10 @@ class TestSTTIntegration(BaseTest):
     async def test_agent_stt_only_without_tts(self, mia_audio_16khz):
         """
         Real integration test: Agent with STT but no TTS.
-
-        Uses real components (Deepgram STT, OpenAI LLM, Stream Edge)
+
+        Uses real components (Deepgram STT, OpenAI LLM, Stream Edge)
         to verify STT-only agents work end-to-end.
-
+
         This test verifies:
         - Agent can be created with STT but without TTS
         - Agent correctly identifies need for audio input
@@ -284,14 +289,12 @@
         missing_keys = [key for key in required_keys if not os.getenv(key)]
         if missing_keys:
             pytest.skip(f"Missing required API keys: {', '.join(missing_keys)}")
-
-
-
+
         edge = getstream.Edge()
         llm = openai.LLM(model="gpt-4o-mini")
         # Create STT with correct sample rate to match our test audio
         stt = deepgram.STT(sample_rate=16000)
-
+
         # Create agent with STT but explicitly NO TTS
         agent = Agent(
             edge=edge,
@@ -301,55 +304,57 @@
             tts=None,  # ← KEY: No TTS - this is what we're testing
             instructions="You are a test agent for STT-only support.",
         )
-
+
         # Test 1: Verify agent needs audio input (because STT is present)
-        assert agent._needs_audio_or_video_input() is True, \
+        assert agent._needs_audio_or_video_input() is True, (
             "Agent with STT should need audio input"
-
+        )
+
         # Test 2: Verify agent does NOT publish audio (because TTS is None)
-        assert agent.publish_audio is False, \
+        assert agent.publish_audio is False, (
             "Agent without TTS should not publish audio"
-
+        )
+
         # Test 3: Set up event listeners to capture transcript
         transcript_events = []
-
+
         @agent.events.subscribe
         async def on_transcript(event: STTTranscriptEvent):
             transcript_events.append(event)
-
+
         # Test 4: Create a test participant (user sending audio)
         test_user = User(name="Test User", id="test_user")
         test_participant = Participant(
             original=test_user,  # The original user object
             user_id="test_user",  # User ID
         )
-
+
         # Test 5: Send real audio through the agent's audio processing path
         # This simulates what happens when a user speaks in a call
         await agent._reply_to_audio(mia_audio_16khz, test_participant)
-
+
         # Test 6: Wait for STT to process and emit transcript
         # Real STT takes time to process audio and establish connection
         await asyncio.sleep(5.0)
-
+
         # Test 7: Verify that transcript event was emitted
-        assert len(transcript_events) > 0, \
+        assert len(transcript_events) > 0, (
             "STT should have emitted at least one transcript event"
-
+        )
+
         # Test 8: Verify transcript has content
         first_transcript = transcript_events[0]
-        assert first_transcript.text is not None, \
-            "Transcript should have text content"
-        assert len(first_transcript.text) > 0, \
-            "Transcript text should not be empty"
-
+        assert first_transcript.text is not None, "Transcript should have text content"
+        assert len(first_transcript.text) > 0, "Transcript text should not be empty"
+
         # Test 9: Verify user metadata is present
-        assert first_transcript.user_metadata is not None, \
+        assert first_transcript.user_metadata is not None, (
             "Transcript should have user metadata"
-
+        )
+
         # Log the transcript for debugging
         print(f"✅ STT transcribed: '{first_transcript.text}'")
-
+
         # Test 10: Clean up
         await stt.close()
         await agent.close()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 89c3bfa5..abae773a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -7,89 +7,97 @@
 
 class TestParseInstructions:
     """Test suite for the parse_instructions function."""
-
+
     def test_parse_instructions_no_mentions(self):
         """Test parsing text with no @ mentions."""
         text = "This is a simple instruction without any mentions."
         result = parse_instructions(text)
-
+
         assert isinstance(result, Instructions)
         assert result.input_text == text
         assert result.markdown_contents == {}
-
+
     def test_parse_instructions_single_mention(self):
         """Test parsing text with a single @ mention."""
         text = "Please read @nonexistent.md for more information."
         result = parse_instructions(text)
-
+
         assert result.input_text == text
         assert result.markdown_contents == {"nonexistent.md": ""}  # File doesn't exist
-
+
     def test_parse_instructions_multiple_mentions(self):
         """Test parsing text with multiple @ mentions."""
         text = "Check @file1.md and @file2.md for details. Also see @guide.md."
         result = parse_instructions(text)
-
+
         assert result.input_text == text
-        assert result.markdown_contents == {"file1.md": "", "file2.md": "", "guide.md": ""}
-
+        assert result.markdown_contents == {
+            "file1.md": "",
+            "file2.md": "",
+            "guide.md": "",
+        }
+
     def test_parse_instructions_duplicate_mentions(self):
         """Test parsing text with duplicate @ mentions."""
         text = "Read @nonexistent.md and then @nonexistent.md again."
         result = parse_instructions(text)
-
+
         assert result.input_text == text
         # Should only include unique filenames
         assert result.markdown_contents == {"nonexistent.md": ""}
-
+
     def test_parse_instructions_non_markdown_mentions(self):
         """Test parsing text with @ mentions that are not markdown files."""
         text = "Check @user123 and @file.txt for information."
         result = parse_instructions(text)
-
+
         assert result.input_text == text
         # Should only capture .md files
         assert result.markdown_contents == {}
-
+
     def test_parse_instructions_mixed_mentions(self):
         """Test parsing text with both markdown and non-markdown @ mentions."""
         text = "Check @user123, @nonexistent.md, and @config.txt for details."
         result = parse_instructions(text)
-
+
         assert result.input_text == text
         # Should only capture .md files
         assert result.markdown_contents == {"nonexistent.md": ""}
-
+
     def test_parse_instructions_complex_filenames(self):
         """Test parsing text with complex markdown filenames."""
         text = "See @my-file.md, @file_with_underscores.md, and @file-with-dashes.md."
         result = parse_instructions(text)
-
+
         assert result.input_text == text
-        assert result.markdown_contents == {"my-file.md": "", "file_with_underscores.md": "", "file-with-dashes.md": ""}
-
+        assert result.markdown_contents == {
+            "my-file.md": "",
+            "file_with_underscores.md": "",
+            "file-with-dashes.md": "",
+        }
+
     def test_parse_instructions_edge_cases(self):
         """Test parsing text with edge cases."""
         # Empty string
         result = parse_instructions("")
         assert result.input_text == ""
         assert result.markdown_contents == {}
-
+
         # Only @ symbol
         result = parse_instructions("@")
         assert result.input_text == "@"
         assert result.markdown_contents == {}
-
+
         # @ without filename
         result = parse_instructions("Check @ for details")
         assert result.input_text == "Check @ for details"
         assert result.markdown_contents == {}
-
+
         # @ with spaces in filename (should not match)
         result = parse_instructions("Check @my file.md for details")
         assert result.input_text == "Check @my file.md for details"
         assert result.markdown_contents == {}
-
+
     def test_parse_instructions_case_sensitivity(self):
         """Test that @ mentions with different cases are extracted separately."""
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -97,36 +105,40 @@
             # because macOS and Windows use case-insensitive filesystems by default
             file1_path = os.path.join(temp_dir, "Guide.md")
             file2_path = os.path.join(temp_dir, "Help.md")
-
-            with open(file1_path, 'w', encoding='utf-8') as f:
+
+            with open(file1_path, "w", encoding="utf-8") as f:
                 f.write("# Guide Content")
-
-            with open(file2_path, 'w', encoding='utf-8') as f:
+
+            with open(file2_path, "w", encoding="utf-8") as f:
                 f.write("# Help Content")
-
+
             # Test that the parser correctly extracts both case variations from text
             # even if they refer to the same file on case-insensitive filesystems
             text = "Check @Guide.md and @guide.md and @Help.md for information."
             result = parse_instructions(text, base_dir=temp_dir)
-
+
             assert result.input_text == text
             # Parser should extract all mentioned filenames
             assert "Guide.md" in result.markdown_contents
-            assert "guide.md" in result.markdown_contents
+            assert "guide.md" in result.markdown_contents
             assert "Help.md" in result.markdown_contents
             # On case-insensitive systems, Guide.md and guide.md will have same content
             # but the parser still tracks them separately by their @ mention
             assert len(result.markdown_contents["Guide.md"]) > 0
             assert len(result.markdown_contents["Help.md"]) > 0
-
+
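The extraction rules these tests pin down (only `.md` names, no spaces, deduplicated, case preserved) fit a single regex pass. The following is only a sketch of one way to satisfy that contract; the pattern and helper name are assumptions, since `parse_instructions`' actual implementation is not part of this diff:

```python
import re

# Assumed pattern: "@" followed by a run of word characters, dots, or dashes
# that ends in ".md". A space breaks the match, so "@my file.md" yields nothing.
_MENTION_RE = re.compile(r"@([\w.-]+\.md)\b")


def extract_md_mentions(text: str) -> list[str]:
    """Return unique @-mentioned .md filenames, preserving first-seen order."""
    seen: dict[str, None] = {}
    for match in _MENTION_RE.finditer(text):
        seen.setdefault(match.group(1), None)  # dedupe while keeping order
    return list(seen)


# Mirrors the dedup and non-.md behavior asserted above.
assert extract_md_mentions("Read @a.md, @a.md and @b.txt") == ["a.md"]
```

Because the character class allows dots, names like `@file.3.md` match whole, and `@Guide.md` versus `@guide.md` are kept as two distinct keys, which is exactly what the case-sensitivity test checks.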
     def test_parse_instructions_special_characters(self):
         """Test parsing with special characters in filenames."""
         text = "Check @file-1.md, @file_2.md, and @file.3.md for details."
         result = parse_instructions(text)
-
+
         assert result.input_text == text
-        assert result.markdown_contents == {"file-1.md": "", "file_2.md": "", "file.3.md": ""}
-
+        assert result.markdown_contents == {
+            "file-1.md": "",
+            "file_2.md": "",
+            "file.3.md": "",
+        }
+
     def test_parse_instructions_multiline_text(self):
         """Test parsing multiline text with @ mentions."""
         text = """Please review the following files:
@@ -135,82 +147,96 @@
         - @troubleshooting.md for common issues
         """
         result = parse_instructions(text)
-
+
         assert result.input_text == text
-        assert result.markdown_contents == {"setup.md": "", "api.md": "", "troubleshooting.md": ""}
+        assert result.markdown_contents == {
+            "setup.md": "",
+            "api.md": "",
+            "troubleshooting.md": "",
+        }
 
 
 class TestInstructions:
     """Test suite for the Instructions dataclass."""
-
+
     def test_instructions_initialization(self):
         """Test Instructions dataclass initialization."""
         input_text = "Test instruction"
         markdown_contents = {"file1.md": "# File 1 content"}
-
+
         instructions = Instructions(input_text, markdown_contents)
-
+
         assert instructions.input_text == input_text
         assert instructions.markdown_contents == markdown_contents
-
+
     def test_instructions_empty_markdown_files(self):
         """Test Instructions with empty markdown files dict."""
         input_text = "Simple instruction"
         markdown_contents = {}
-
+
         instructions = Instructions(input_text, markdown_contents)
-
+
         assert instructions.input_text == input_text
         assert instructions.markdown_contents == {}
-
+
     def test_instructions_equality(self):
         """Test Instructions equality comparison."""
         instructions1 = Instructions("test", {"file.md": "content"})
         instructions2 = Instructions("test", {"file.md": "content"})
         instructions3 = Instructions("different", {"file.md": "content"})
-
+
         assert instructions1 == instructions2
         assert instructions1 != instructions3
 
 
 class TestParseInstructionsFileReading:
     """Test suite for file reading functionality in parse_instructions."""
-
+
     def test_parse_instructions_with_existing_files(self):
         """Test parsing with actual markdown files that exist."""
         with tempfile.TemporaryDirectory() as temp_dir:
             # Create test markdown files
             file1_path = os.path.join(temp_dir, "readme.md")
             file2_path = os.path.join(temp_dir, "guide.md")
-
-            with open(file1_path, 'w', encoding='utf-8') as f:
+
+            with open(file1_path, "w", encoding="utf-8") as f:
                 f.write("# README\n\nThis is a test readme file.")
-
-            with open(file2_path, 'w', encoding='utf-8') as f:
+
+            with open(file2_path, "w", encoding="utf-8") as f:
                 f.write("# Guide\n\nThis is a test guide file.")
-
+
             text = "Please read @readme.md and @guide.md for information."
             result = parse_instructions(text, base_dir=temp_dir)
-
+
             assert result.input_text == text
-            assert result.markdown_contents["readme.md"] == "# README\n\nThis is a test readme file."
-            assert result.markdown_contents["guide.md"] == "# Guide\n\nThis is a test guide file."
-
+            assert (
+                result.markdown_contents["readme.md"]
+                == "# README\n\nThis is a test readme file."
+            )
+            assert (
+                result.markdown_contents["guide.md"]
+                == "# Guide\n\nThis is a test guide file."
+            )
+
     def test_parse_instructions_with_mixed_existing_nonexisting_files(self):
         """Test parsing with mix of existing and non-existing files."""
         with tempfile.TemporaryDirectory() as temp_dir:
             # Create only one test file
             file1_path = os.path.join(temp_dir, "readme.md")
-            with open(file1_path, 'w', encoding='utf-8') as f:
+            with open(file1_path, "w", encoding="utf-8") as f:
                 f.write("# README\n\nThis file exists.")
-
+
             text = "Check @readme.md and @nonexistent.md for details."
             result = parse_instructions(text, base_dir=temp_dir)
-
+
             assert result.input_text == text
-            assert result.markdown_contents["readme.md"] == "# README\n\nThis file exists."
-            assert result.markdown_contents["nonexistent.md"] == ""  # Empty for non-existing file
-
+            assert (
+                result.markdown_contents["readme.md"] == "# README\n\nThis file exists."
+            )
+            assert (
+                result.markdown_contents["nonexistent.md"] == ""
+            )  # Empty for non-existing file
+
     def test_parse_instructions_with_custom_base_dir(self):
         """Test parsing with custom base directory."""
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -218,70 +244,77 @@
             subdir = os.path.join(temp_dir, "docs")
             os.makedirs(subdir)
             file_path = os.path.join(subdir, "api.md")
-
-            with open(file_path, 'w', encoding='utf-8') as f:
+
+            with open(file_path, "w", encoding="utf-8") as f:
                 f.write("# API Documentation\n\nThis is the API docs.")
-
+
             text = "See @api.md for API information."
             result = parse_instructions(text, base_dir=subdir)
-
+
             assert result.input_text == text
-            assert result.markdown_contents["api.md"] == "# API Documentation\n\nThis is the API docs."
-
+            assert (
+                result.markdown_contents["api.md"]
+                == "# API Documentation\n\nThis is the API docs."
+            )
+
     def test_parse_instructions_file_read_error_handling(self):
         """Test handling of file read errors."""
         with tempfile.TemporaryDirectory() as temp_dir:
             # Create a file that will cause read errors (permission issues, etc.)
             file_path = os.path.join(temp_dir, "readme.md")
-            with open(file_path, 'w', encoding='utf-8') as f:
+            with open(file_path, "w", encoding="utf-8") as f:
                 f.write("test content")
-
+
             # Make file unreadable (this might not work on all systems)
             try:
                 os.chmod(file_path, 0o000)  # No permissions
-
+
                 text = "Read @readme.md for information."
                 result = parse_instructions(text, base_dir=temp_dir)
-
+
                 assert result.input_text == text
-                assert result.markdown_contents["readme.md"] == ""  # Empty due to read error
+                assert (
+                    result.markdown_contents["readme.md"] == ""
+                )  # Empty due to read error
             finally:
                 # Restore permissions for cleanup
                 os.chmod(file_path, 0o644)
-
+
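The file-reading tests encode a deliberately forgiving contract: every mentioned file resolves against `base_dir` (defaulting to the current working directory), reads as UTF-8, and collapses to an empty string on any read failure rather than raising. A sketch of a loader with that behavior, assuming the mention list comes from a pass like the one sketched earlier; the helper name is illustrative, not the library's:

```python
import os
from typing import Dict, List, Optional


def load_mention_contents(
    filenames: List[str], base_dir: Optional[str] = None
) -> Dict[str, str]:
    """Read each mentioned file as UTF-8; unreadable files map to ""."""
    base = base_dir if base_dir is not None else os.getcwd()
    contents: Dict[str, str] = {}
    for name in filenames:
        try:
            with open(os.path.join(base, name), "r", encoding="utf-8") as f:
                contents[name] = f.read()
        except OSError:
            # Missing files and permission errors both degrade to "",
            # which is what the tests above assert.
            contents[name] = ""
    return contents
```

Swallowing `OSError` covers both the nonexistent-file and the `chmod 0o000` cases in one branch, which is why the tests can assert an empty string for each without distinguishing the failure mode.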
     def test_parse_instructions_unicode_content(self):
         """Test parsing with unicode content in markdown files."""
         with tempfile.TemporaryDirectory() as temp_dir:
             file_path = os.path.join(temp_dir, "unicode.md")
-
+
             # Write unicode content
-            unicode_content = "# Unicode Test\n\nHello 世界! 🌍\n\nThis has émojis and àccénts."
-            with open(file_path, 'w', encoding='utf-8') as f:
+            unicode_content = (
+                "# Unicode Test\n\nHello 世界! 🌍\n\nThis has émojis and àccénts."
+            )
+            with open(file_path, "w", encoding="utf-8") as f:
                 f.write(unicode_content)
-
+
             text = "Check @unicode.md for unicode content."
             result = parse_instructions(text, base_dir=temp_dir)
-
+
             assert result.input_text == text
             assert result.markdown_contents["unicode.md"] == unicode_content
-
+
     def test_parse_instructions_default_base_dir(self):
         """Test that default base directory is current working directory."""
         with tempfile.TemporaryDirectory() as temp_dir:
             # Create a test file
             file_path = os.path.join(temp_dir, "readme.md")
-            with open(file_path, 'w', encoding='utf-8') as f:
+            with open(file_path, "w", encoding="utf-8") as f:
                 f.write("# Test readme content")
-
+
             # Change to temp directory to test default base_dir
             original_cwd = os.getcwd()
             try:
                 os.chdir(temp_dir)
-
+
                 # This test verifies that when no base_dir is provided, it uses os.getcwd()
                 text = "Read @readme.md for information."
                 result = parse_instructions(text)  # No base_dir provided
-
+
                 assert result.input_text == text
                 # Content will not be empty since readme.md exists in current directory
                 assert "readme.md" in result.markdown_contents
@@ -294,91 +327,91 @@
 
 class TestPcmDataMethods:
     """Test suite for PcmData class methods."""
-
+
     def test_pcm_data_from_bytes(self):
         """Test PcmData.from_bytes class method."""
         # Create test audio data (1 second of 16kHz audio)
         test_samples = np.random.randint(-32768, 32767, 16000, dtype=np.int16)
         audio_bytes = test_samples.tobytes()
-
+
         pcm_data = PcmData.from_bytes(audio_bytes, sample_rate=16000, format="s16")
-
+
         assert pcm_data.sample_rate == 16000
         assert pcm_data.format == "s16"
         assert np.array_equal(pcm_data.samples, test_samples)
         assert pcm_data.duration == 1.0  # 1 second
-
+
     def test_pcm_data_resample_same_rate(self):
         """Test resampling when source and target rates are the same."""
         test_samples = np.random.randint(-32768, 32767, 16000, dtype=np.int16)
         pcm_data = PcmData(samples=test_samples, sample_rate=16000, format="s16")
-
+
         resampled = pcm_data.resample(target_sample_rate=16000)
-
+
         # Should return the same data
         assert resampled.sample_rate == 16000
         assert np.array_equal(resampled.samples, test_samples)
         assert resampled.format == "s16"
-
+
     def test_pcm_data_resample_24khz_to_48khz(self):
         """Test resampling from 24kHz to 48kHz (Gemini use case)."""
         # Create test audio data (1 second of 24kHz audio)
         test_samples = np.random.randint(-32768, 32767, 24000, dtype=np.int16)
         pcm_data = PcmData(samples=test_samples, sample_rate=24000, format="s16")
-
+
         resampled = pcm_data.resample(target_sample_rate=48000)
-
+
         assert resampled.sample_rate == 48000
         assert resampled.format == "s16"
         # Should have approximately double the samples (24k -> 48k)
         assert abs(len(resampled.samples) - 48000) < 100  # Allow some tolerance
         # Duration should be approximately the same
         assert abs(resampled.duration - 1.0) < 0.1
-
+
     def test_pcm_data_resample_48khz_to_16khz(self):
         """Test resampling from 48kHz to 16kHz."""
         # Create test audio data (1 second of 48kHz audio)
         test_samples = np.random.randint(-32768, 32767, 48000, dtype=np.int16)
         pcm_data = PcmData(samples=test_samples, sample_rate=48000, format="s16")
-
+
         resampled = pcm_data.resample(target_sample_rate=16000)
-
+
         assert resampled.sample_rate == 16000
         assert resampled.format == "s16"
         # Should have approximately 1/3 the samples (48k -> 16k)
         assert abs(len(resampled.samples) - 16000) < 100  # Allow some tolerance
         # Duration should be approximately the same
         assert abs(resampled.duration - 1.0) < 0.1
-
+
     def test_pcm_data_resample_preserves_metadata(self):
         """Test that resampling preserves PTS, DTS, and time_base metadata."""
         test_samples = np.random.randint(-32768, 32767, 16000, dtype=np.int16)
         pcm_data = PcmData(
-            samples=test_samples,
-            sample_rate=16000,
+            samples=test_samples,
+            sample_rate=16000,
             format="s16",
             pts=1000,
             dts=950,
-            time_base=0.001
+            time_base=0.001,
         )
-
+
         resampled = pcm_data.resample(target_sample_rate=48000)
-
+
         assert resampled.pts == 1000
         assert resampled.dts == 950
         assert resampled.time_base == 0.001
         assert abs(resampled.pts_seconds - 1.0) < 0.0001
         assert abs(resampled.dts_seconds - 0.95) < 0.0001
-
+
     def test_pcm_data_resample_handles_1d_array(self):
         """Test that resampling handles 1D arrays correctly (fixes ndim error)."""
         # Create test audio data (1 second of 24kHz audio) - 1D array
         test_samples = np.random.randint(-32768, 32767, 24000, dtype=np.int16)
         pcm_data = PcmData(samples=test_samples, sample_rate=24000, format="s16")
-
+
         # This should now work without the ndim error
         resampled = pcm_data.resample(target_sample_rate=48000)
-
+
         assert resampled.sample_rate == 48000
         assert resampled.format == "s16"
         # Should have approximately double the samples (24k -> 48k)
@@ -387,16 +420,16 @@
         assert abs(resampled.duration - 1.0) < 0.1
         # Output should be 1D array
         assert resampled.samples.ndim == 1
-
+
     def test_pcm_data_resample_handles_2d_array(self):
         """Test that resampling handles 2D arrays correctly."""
         # Create test audio data (1 second of 24kHz audio) - 2D array (channels, samples)
         test_samples = np.random.randint(-32768, 32767, (1, 24000), dtype=np.int16)
         pcm_data = PcmData(samples=test_samples, sample_rate=24000, format="s16")
-
+
         # This should work with 2D arrays too
         resampled = pcm_data.resample(target_sample_rate=48000)
-
+
         assert resampled.sample_rate == 48000
         assert resampled.format == "s16"
         # Should have approximately double the samples (24k -> 48k)
@@ -405,17 +438,17 @@
         assert abs(resampled.duration - 1.0) < 0.1
         # Output should be 1D array (flattened)
         assert resampled.samples.ndim == 1
-
+
     def test_pcm_data_from_bytes_and_resample_chain(self):
         """Test chaining from_bytes and resample methods (Gemini use case)."""
         # Create test audio data (1 second of 24kHz audio)
         test_samples = np.random.randint(-32768, 32767, 24000, dtype=np.int16)
         audio_bytes = test_samples.tobytes()
-
+
         # Chain the methods like in realtime2.py
         pcm_data = PcmData.from_bytes(audio_bytes, sample_rate=24000, format="s16")
         resampled_pcm = pcm_data.resample(target_sample_rate=48000)
-
+
         assert pcm_data.sample_rate == 24000
         assert resampled_pcm.sample_rate == 48000
         assert resampled_pcm.format == "s16"
@@ -423,16 +456,18 @@
         assert abs(len(resampled_pcm.samples) - 48000) < 100  # Allow some tolerance
         # Duration should be approximately the same
         assert abs(resampled_pcm.duration - 1.0) < 0.1
-
+
     def test_pcm_data_resample_av_array_shape_fix(self):
         """Test that fixes the AV library array shape error (channels, samples)."""
         # Create test audio data that would cause the "Expected packed array.shape[0] to equal 1" error
-        test_samples = np.random.randint(-32768, 32767, 1920, dtype=np.int16)  # Small chunk like in the error
+        test_samples = np.random.randint(
+            -32768, 32767, 1920, dtype=np.int16
+        )  # Small chunk like in the error
         pcm_data = PcmData(samples=test_samples, sample_rate=24000, format="s16")
-
+
         # This should work without the array shape error
         resampled = pcm_data.resample(target_sample_rate=48000)
-
+
         assert resampled.sample_rate == 48000
         assert resampled.format == "s16"
         # Should have approximately double the samples (1920 -> ~3840)
@@ -442,4 +477,3 @@
 
 
 # Shared fixtures for integration tests
-
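The resampling tests above check rate-ratio and duration invariants rather than signal quality: output length scales by `target_rate / source_rate`, duration is preserved, and the result is always a 1-D int16 array. A toy linear-interpolation resampler makes those invariants concrete; `PcmData.resample` itself presumably delegates to PyAV, so this is only an illustration of the math, not the library's code path:

```python
import numpy as np


def resample_linear(samples: np.ndarray, src_rate: int, dst_rate: int) -> np.ndarray:
    """Resample int16 PCM by linear interpolation, returning a 1-D array."""
    flat = samples.reshape(-1).astype(np.float64)  # accept 1-D or (channels, n)
    if src_rate == dst_rate:
        return flat.astype(np.int16)
    # Output length scales by the rate ratio, so duration stays constant.
    n_out = int(round(len(flat) * dst_rate / src_rate))
    x_out = np.linspace(0.0, len(flat) - 1.0, n_out)
    out = np.interp(x_out, np.arange(len(flat)), flat)
    return out.astype(np.int16)


# 1 second at 24 kHz becomes ~1 second at 48 kHz: twice the samples, 1-D output.
one_sec = np.zeros(24000, dtype=np.int16)
assert resample_linear(one_sec, 24000, 48000).shape == (48000,)
```

The `reshape(-1)` up front is the same normalization the 1-D/2-D tests demand: a `(1, 24000)` channels-first buffer and a flat 24000-sample buffer both come out as a single 1-D array.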