-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathagents.py
More file actions
150 lines (125 loc) · 4.91 KB
/
agents.py
File metadata and controls
150 lines (125 loc) · 4.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import warnings
from typing import Dict, Callable, Union, TYPE_CHECKING
from abc import ABCMeta, abstractmethod
from core.data import data_packet
from mangrove import (
VADStage,
STTStage,
BotStage,
TTSStage,
)
from storage_manager import StorageManager
from core import AudioBuffer, DataPacket, AudioPacket, TextPacket
from core.stage import PipelineSequence, PipelineStage
from core.utils import logger
if TYPE_CHECKING:
    # Imported only for static type checking (used in string annotations such
    # as ``"SocketIONamespace"``); it is never imported at runtime, presumably
    # to avoid a circular import with the ``host`` module — confirm.
    from host import SocketIONamespace
# TODO check on this later
# NOTE(review): this runs at import time and silences *every* UserWarning in
# the whole process, not only those raised by this module's dependencies.
# Consider narrowing it with ``module=`` / ``message=`` filters.
warnings.filterwarnings("ignore", category=UserWarning)
class TextBasedAgentPipeline(PipelineSequence):
    """Text-only agent pipeline.

    Both the packets it accepts and the packets it produces are ``TextPacket``.
    """

    # Same packet type on both ends of the pipeline.
    input_type = output_type = TextPacket
class VoiceCapableAgentPipeline(PipelineSequence):
    """Pipeline for voice-capable agent processing.

    Accepts and produces ``AudioPacket``s. Its lifecycle hooks emit an optional
    startup greeting on connect and persist the session's audio buffer on
    disconnect.
    """

    input_type = AudioPacket
    output_type = AudioPacket

    def on_start(self):
        """Start the pipeline and allocate a fresh per-session audio buffer."""
        super().on_start()
        self.session_audio_buffer = AudioBuffer()

    def on_connect(self):
        """Log the connection and emit the startup greeting audio, if one is set."""
        logger.info("Connected to the server.")
        # Nothing in this class assigns ``startup_audiopacket`` (the owning agent
        # may or may not set it on the pipeline), so read it defensively instead
        # of raising AttributeError on every connection.
        startup_audiopacket = getattr(self, "startup_audiopacket", None)
        if startup_audiopacket:
            from copy import deepcopy

            # Emit a copy, presumably so the stored greeting is not altered by
            # whatever the host does with the emitted packet.
            self._host.emit_bot_voice(deepcopy(startup_audiopacket))
        logger.info("Ready to receive audio packets.")

    def on_disconnect(self):
        """Clean up upon disconnection.

        Writes the buffered session audio through ``StorageManager`` and waits
        for storage to complete. Returns early (after logging) when no audio was
        buffered, or when ``on_start`` never ran and no buffer exists.
        """
        logger.info("Disconnected from the server.")
        session_audio_buffer = getattr(self, "session_audio_buffer", None)
        if session_audio_buffer is None or session_audio_buffer.is_empty():
            return
        StorageManager.write_audio_file(session_audio_buffer.dump_to_packet())
        StorageManager.ensure_completion()
        logger.info("Session completed.")
class Agent(metaclass=ABCMeta):
    """Abstract base for every agent.

    Concrete agents must implement ``feed`` and ``start``. The ``on_*`` hooks
    only log by default and may be overridden.
    """

    def __init__(self):
        """Use the concrete subclass's name as the agent's display name."""
        self.name = type(self).__name__

    def on_start(self):
        """Hook run when the agent starts; logs the event by default."""
        logger.info("{} agent started.".format(self.name))

    def on_connect(self):
        """Hook run when the agent connects to the server; logs by default."""
        logger.info("{} agent connected.".format(self.name))

    def on_disconnect(self):
        """Hook run when the agent disconnects from the server; logs by default."""
        logger.info("{} agent disconnected.".format(self.name))

    @abstractmethod
    def feed(self, data_packet: DataPacket):
        """Hand one data packet to the agent for processing.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        message = "This method should be implemented by subclasses."
        raise NotImplementedError(message)

    @abstractmethod
    def start(self, host):
        """Begin running the agent against ``host``.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        message = "This method should be implemented by subclasses."
        raise NotImplementedError(message)
class BasicConversationalAgent(Agent):
    """Agent controller for the conversational AI server.

    Wraps either a text-only pipeline (bot stage only) or a voice pipeline
    (VAD -> STT -> bot -> TTS), and wires the host's emitter callbacks into the
    pipeline when started.
    """

    # Fallback endpoint per stage. Never mutated: ``__init__`` merges it into a
    # fresh dict together with any caller-supplied overrides.
    _DEFAULT_ENDPOINTS: Dict[str, str] = {
        "bot": "openai",
        "tts": "gtts",
    }

    def __init__(
        self,
        text_only: bool = False,
        device=None,
        endpoints: Union[Dict[str, str], None] = None,
        persona_configs: Union[str, Dict, None] = None,
        welcome_msg: str = "Welcome, AI server connection is succesful.",
        verbose=False,
    ):
        """Build the agent's stage pipeline.

        Args:
            text_only: If True, build ``TextBasedAgentPipeline`` containing only
                the bot stage. Otherwise build ``VoiceCapableAgentPipeline`` with
                VAD -> STT -> bot -> TTS stages.
            device: Passed to the VAD and STT stages; unused when ``text_only``.
            endpoints: Endpoint name per stage. Recognised keys are ``"bot"``
                (default ``"openai"``) and ``"tts"`` (default ``"gtts"``).
                Omitted keys fall back to those defaults.
            persona_configs: Forwarded to the bot stage.
            welcome_msg: Intended greeting text. Currently unused because
                greeting synthesis is disabled (see the TODO below).
            verbose: Forwarded to the bot stage and to the pipeline.
        """
        super().__init__()
        # Merge into a new dict: avoids the shared-mutable-default pitfall and
        # lets callers override a single endpoint without supplying both keys.
        resolved_endpoints = {**self._DEFAULT_ENDPOINTS, **(endpoints or {})}

        bot = BotStage(
            name="bot",
            endpoint=resolved_endpoints["bot"],
            persona_configs=persona_configs,
            verbose=verbose,
        )
        # Greeting audio emitted by the voice pipeline on connect.
        # TODO: populate via ``tts.read(welcome_msg, as_generator=False)``
        # once greeting synthesis is re-enabled; until then it stays None.
        self.startup_audiopacket = None

        self._pipeline: Union[TextBasedAgentPipeline, VoiceCapableAgentPipeline]
        if text_only:
            self._pipeline = TextBasedAgentPipeline(
                name="text_based_agent_pipeline",
                stages=[bot],
                verbose=verbose,
            )
        else:
            vad = VADStage(name="vad", device=device)
            stt = STTStage(name="stt", device=device)
            tts = TTSStage(name="tts", endpoint=resolved_endpoints["tts"])
            self._pipeline = VoiceCapableAgentPipeline(
                name="voice_capable_agent_pipeline",
                stages=[vad, stt, bot, tts],
                verbose=verbose,
            )
            # VoiceCapableAgentPipeline.on_connect reads ``startup_audiopacket``
            # from the pipeline itself, so hand the value over explicitly.
            self._pipeline.startup_audiopacket = self.startup_audiopacket
        self._text_only = text_only

    def start(self, host: "SocketIONamespace"):
        """Start the agent with the given host.

        Stores ``host``, registers its emitter callbacks in the pipeline's
        ``response_emission_mapping`` under the keys ``"stt"``, ``"bot"`` and
        ``"tts"``, then starts the pipeline with that host.
        """
        self.host = host
        self._pipeline.response_emission_mapping = {
            "stt": self.host.emit_stt_response,
            "bot": self.host.emit_bot_response,
            "tts": self.host.emit_bot_voice,
        }
        self._pipeline.start(host=self.host)

    def feed(self, data_packet: DataPacket):
        """Feed a data packet to the agent's pipeline.

        Raises:
            ValueError: if ``data_packet`` is not an instance of the pipeline's
                ``input_type``.
        """
        if not isinstance(data_packet, self._pipeline.input_type):
            raise ValueError(f"Cannot feed data packet of type {type(data_packet)} to the agent pipeline {self._pipeline.name}. Expected type {self._pipeline.input_type}.")
        self._pipeline.feed(data_packet)