Skip to content

Commit 4eae232

Browse files
committed
Add voice agent bridge
1 parent a46a4c9 commit 4eae232

File tree

1 file changed

+306
-0
lines changed

1 file changed

+306
-0
lines changed
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
#!/usr/bin/env python3
2+
"""
3+
ROS2 Bridge Node for LiveKit Voice Agent
4+
5+
This node acts as a bridge between the LiveKit voice agent (running as a separate process)
6+
and the ROS2 ecosystem. It communicates with the voice agent via WebSocket and provides
7+
ROS2 topics and services for integration with other robot components.
8+
"""
9+
10+
import json
11+
import asyncio
12+
import threading
13+
from typing import Optional
14+
15+
import rclpy
16+
from rclpy.node import Node
17+
from rclpy.executors import MultiThreadedExecutor
18+
from rclpy.callback_groups import ReentrantCallbackGroup
19+
from std_msgs.msg import String, Bool
20+
from geometry_msgs.msg import Twist
21+
22+
try:
23+
import websockets
24+
import websockets.client
25+
except ImportError:
26+
print("websockets not available - bridge will not function")
27+
websockets = None
28+
29+
30+
class VoiceAgentBridge(Node):
31+
"""ROS2 Bridge Node for LiveKit Voice Agent Communication"""
32+
33+
def __init__(self):
34+
super().__init__('voice_agent_bridge')
35+
36+
# Parameters
37+
self.declare_parameter('voice_agent_host', 'localhost')
38+
self.declare_parameter('voice_agent_port', 8080)
39+
self.declare_parameter('reconnect_interval', 5.0)
40+
41+
self.host = self.get_parameter('voice_agent_host').value
42+
self.port = self.get_parameter('voice_agent_port').value
43+
self.reconnect_interval = self.get_parameter('reconnect_interval').value
44+
45+
# WebSocket connection
46+
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
47+
self.connection_active = False
48+
self.reconnect_task: Optional[asyncio.Task] = None
49+
50+
# Callback group for async operations
51+
self.callback_group = ReentrantCallbackGroup()
52+
53+
# ROS2 Publishers (Voice Agent → ROS2)
54+
self.state_pub = self.create_publisher(
55+
String,
56+
'voice_agent/state',
57+
10,
58+
callback_group=self.callback_group
59+
)
60+
61+
self.conversation_pub = self.create_publisher(
62+
String,
63+
'voice_agent/conversation',
64+
10,
65+
callback_group=self.callback_group
66+
)
67+
68+
self.emotion_pub = self.create_publisher(
69+
String,
70+
'voice_agent/emotion',
71+
10,
72+
callback_group=self.callback_group
73+
)
74+
75+
self.connected_pub = self.create_publisher(
76+
Bool,
77+
'voice_agent/connected',
78+
10,
79+
callback_group=self.callback_group
80+
)
81+
82+
# ROS2 Subscribers (ROS2 → Voice Agent)
83+
self.virtual_request_sub = self.create_subscription(
84+
String,
85+
'voice_agent/virtual_requests',
86+
self.handle_virtual_request,
87+
10,
88+
callback_group=self.callback_group
89+
)
90+
91+
self.command_sub = self.create_subscription(
92+
String,
93+
'voice_agent/commands',
94+
self.handle_command,
95+
10,
96+
callback_group=self.callback_group
97+
)
98+
99+
# Start WebSocket connection in separate thread
100+
self.websocket_thread = threading.Thread(target=self._run_websocket_client, daemon=True)
101+
self.websocket_thread.start()
102+
103+
# Status timer
104+
self.status_timer = self.create_timer(1.0, self.publish_connection_status)
105+
106+
self.get_logger().info(f"Voice Agent Bridge initialized - connecting to {self.host}:{self.port}")
107+
108+
def _run_websocket_client(self):
109+
"""Run WebSocket client in separate thread with its own event loop"""
110+
loop = asyncio.new_event_loop()
111+
asyncio.set_event_loop(loop)
112+
113+
try:
114+
loop.run_until_complete(self._maintain_connection())
115+
except Exception as e:
116+
self.get_logger().error(f"WebSocket client error: {e}")
117+
finally:
118+
loop.close()
119+
120+
async def _maintain_connection(self):
121+
"""Maintain WebSocket connection with automatic reconnection"""
122+
while rclpy.ok():
123+
try:
124+
uri = f"ws://{self.host}:{self.port}"
125+
self.get_logger().info(f"Attempting to connect to voice agent at {uri}")
126+
127+
async with websockets.client.connect(uri) as websocket:
128+
self.websocket = websocket
129+
self.connection_active = True
130+
self.get_logger().info("Connected to voice agent WebSocket")
131+
132+
# Listen for messages from voice agent
133+
async for message in websocket:
134+
await self._handle_websocket_message(message)
135+
136+
except websockets.exceptions.ConnectionClosed:
137+
self.get_logger().warn("WebSocket connection closed")
138+
except Exception as e:
139+
self.get_logger().error(f"WebSocket connection error: {e}")
140+
finally:
141+
self.connection_active = False
142+
self.websocket = None
143+
144+
if rclpy.ok():
145+
self.get_logger().info(f"Reconnecting in {self.reconnect_interval} seconds...")
146+
await asyncio.sleep(self.reconnect_interval)
147+
148+
async def _handle_websocket_message(self, message: str):
149+
"""Handle incoming messages from voice agent"""
150+
try:
151+
data = json.loads(message)
152+
message_type = data.get('type', 'unknown')
153+
154+
if message_type == 'STATE_CHANGE':
155+
# Publish agent state change
156+
state_msg = String()
157+
state_msg.data = json.dumps({
158+
'state': data.get('state'),
159+
'timestamp': data.get('timestamp'),
160+
'previous_state': data.get('previous_state')
161+
})
162+
self.state_pub.publish(state_msg)
163+
164+
elif message_type == 'CONVERSATION':
165+
# Publish conversation transcript
166+
conv_msg = String()
167+
conv_msg.data = json.dumps({
168+
'role': data.get('role'),
169+
'text': data.get('text'),
170+
'timestamp': data.get('timestamp')
171+
})
172+
self.conversation_pub.publish(conv_msg)
173+
174+
elif message_type == 'EMOTION':
175+
# Publish emotion change
176+
emotion_msg = String()
177+
emotion_msg.data = json.dumps({
178+
'emotion': data.get('emotion'),
179+
'previous_emotion': data.get('previous_emotion'),
180+
'timestamp': data.get('timestamp')
181+
})
182+
self.emotion_pub.publish(emotion_msg)
183+
184+
elif message_type == 'STATUS':
185+
# Handle status updates
186+
self.get_logger().info(f"Voice agent status: {data.get('message', 'Unknown')}")
187+
188+
else:
189+
self.get_logger().warn(f"Unknown message type from voice agent: {message_type}")
190+
191+
except json.JSONDecodeError as e:
192+
self.get_logger().error(f"Invalid JSON from voice agent: {e}")
193+
except Exception as e:
194+
self.get_logger().error(f"Error handling voice agent message: {e}")
195+
196+
def handle_virtual_request(self, msg: String):
197+
"""Handle virtual request from ROS2 and forward to voice agent"""
198+
try:
199+
# Parse the ROS2 message
200+
request_data = json.loads(msg.data)
201+
202+
# Format for voice agent
203+
command = {
204+
'type': 'VIRTUAL_REQUEST',
205+
'request_type': request_data.get('request_type', 'NEW_COFFEE_REQUEST'),
206+
'content': request_data.get('content', ''),
207+
'priority': request_data.get('priority', 'normal'),
208+
'timestamp': request_data.get('timestamp')
209+
}
210+
211+
# Send to voice agent
212+
asyncio.run_coroutine_threadsafe(
213+
self._send_to_voice_agent(command),
214+
self.websocket_thread._target.__globals__.get('loop') if hasattr(self.websocket_thread, '_target') else None
215+
)
216+
217+
self.get_logger().info(f"Forwarded virtual request: {request_data.get('request_type')}")
218+
219+
except json.JSONDecodeError as e:
220+
self.get_logger().error(f"Invalid JSON in virtual request: {e}")
221+
except Exception as e:
222+
self.get_logger().error(f"Error handling virtual request: {e}")
223+
224+
def handle_command(self, msg: String):
225+
"""Handle command from ROS2 and forward to voice agent"""
226+
try:
227+
# Parse the command
228+
command_data = json.loads(msg.data)
229+
230+
# Format for voice agent
231+
command = {
232+
'type': 'COMMAND',
233+
'action': command_data.get('action'),
234+
'parameters': command_data.get('parameters', {}),
235+
'timestamp': command_data.get('timestamp')
236+
}
237+
238+
# Send to voice agent
239+
asyncio.run_coroutine_threadsafe(
240+
self._send_to_voice_agent(command),
241+
self.websocket_thread._target.__globals__.get('loop') if hasattr(self.websocket_thread, '_target') else None
242+
)
243+
244+
self.get_logger().info(f"Forwarded command: {command_data.get('action')}")
245+
246+
except json.JSONDecodeError as e:
247+
self.get_logger().error(f"Invalid JSON in command: {e}")
248+
except Exception as e:
249+
self.get_logger().error(f"Error handling command: {e}")
250+
251+
async def _send_to_voice_agent(self, data: dict):
252+
"""Send data to voice agent via WebSocket"""
253+
if self.websocket and self.connection_active:
254+
try:
255+
message = json.dumps(data)
256+
await self.websocket.send(message)
257+
except Exception as e:
258+
self.get_logger().error(f"Error sending to voice agent: {e}")
259+
else:
260+
self.get_logger().warn("Cannot send to voice agent - not connected")
261+
262+
def publish_connection_status(self):
263+
"""Publish connection status periodically"""
264+
status_msg = Bool()
265+
status_msg.data = self.connection_active
266+
self.connected_pub.publish(status_msg)
267+
268+
def destroy_node(self):
269+
"""Clean up resources"""
270+
self.connection_active = False
271+
if hasattr(self, 'websocket_thread') and self.websocket_thread.is_alive():
272+
# Give time for graceful shutdown
273+
self.websocket_thread.join(timeout=2.0)
274+
super().destroy_node()
275+
276+
277+
def main(args=None):
278+
"""Main entry point for the bridge node"""
279+
if websockets is None:
280+
print("ERROR: websockets package not available. Install with: pip install websockets")
281+
return
282+
283+
rclpy.init(args=args)
284+
285+
try:
286+
bridge_node = VoiceAgentBridge()
287+
288+
# Use MultiThreadedExecutor to handle async operations
289+
executor = MultiThreadedExecutor()
290+
executor.add_node(bridge_node)
291+
292+
try:
293+
executor.spin()
294+
except KeyboardInterrupt:
295+
bridge_node.get_logger().info("Shutting down voice agent bridge...")
296+
finally:
297+
bridge_node.destroy_node()
298+
299+
except Exception as e:
300+
print(f"Error starting voice agent bridge: {e}")
301+
finally:
302+
rclpy.shutdown()
303+
304+
305+
if __name__ == '__main__':
306+
main()

0 commit comments

Comments
 (0)