Skip to content

Commit 5532c2f

Browse files
authored
refactor: connector api (#532)
1 parent c4ba634 commit 5532c2f

File tree

35 files changed

+393
-458
lines changed

35 files changed

+393
-458
lines changed

docs/developer_guide/tools.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,14 @@ Refer to the [BaseROS2Tool source code](../../src/rai_core/rai/tools/ros2/base.p
152152
Tools can be initialized with parameters such as a connector, enabling custom configurations for ROS 2 environments.
153153

154154
```python
155-
from rai.communication.ros2 import ROS2ARIConnector
155+
from rai.communication.ros2 import ROS2Connector
156156
from rai.tools.ros2 import (
157157
GetROS2ImageTool,
158158
GetROS2TopicsNamesAndTypesTool,
159159
PublishROS2MessageTool,
160160
)
161161

162-
def initialize_tools(connector: ROS2ARIConnector):
162+
def initialize_tools(connector: ROS2Connector):
163163
"""Initialize and configure ROS 2 tools.
164164
165165
Returns:
@@ -193,7 +193,7 @@ TODO(docs): add link to the BaseAgent docs (regarding distributed setup)
193193

194194
```python
195195
from rai.agents import ReActAgent
196-
from rai.communication import ROS2ARIConnector, ROS2HRIConnector
196+
from rai.communication import ROS2Connector, ROS2HRIConnector
197197
from rai.tools.ros2 import ROS2Toolkit
198198
from rai.communication.ros2 import ROS2Context
199199
from rai import AgentRunner
@@ -202,10 +202,10 @@ from rai import AgentRunner
202202
def main() -> None:
203203
"""Initialize and run the RAI agent with configured tools."""
204204
connector = ROS2HRIConnector(sources=["/from_human"], targets=["/to_human"])
205-
ari_connector = ROS2ARIConnector()
205+
ros2_connector = ROS2Connector()
206206
agent = ReActAgent(
207207
connectors={"hri": connector},
208-
tools=initialize_tools(connector=ari_connector),
208+
tools=initialize_tools(connector=ros2_connector),
209209
)
210210
runner = AgentRunner([agent])
211211
runner.run_and_wait_for_shutdown()
@@ -226,9 +226,9 @@ from rai.communication.ros2 import ROS2Context
226226

227227
@ROS2Context()
228228
def main():
229-
ari_connector = ROS2ARIConnector()
229+
ros2_connector = ROS2Connector()
230230
agent = create_react_runnable(
231-
tools=initialize_tools(connector=ari_connector),
231+
tools=initialize_tools(connector=ros2_connector),
232232
system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
233233
)
234234
state = {'messages': []}
@@ -244,5 +244,5 @@ def main():
244244
## Related Topics
245245

246246
- [Connectors](../communication/connectors.md)
247-
- [ROS2ARIConnector](../communication/ros2.md)
247+
- [ROS2Connector](../communication/ros2.md)
248248
- [ROS2HRIConnector](../communication/ros2.md)

examples/agents/react.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313
# limitations under the License.
1414

1515
from rai.agents import AgentRunner, ReActAgent
16-
from rai.communication.ros2 import ROS2ARIConnector, ROS2Context, ROS2HRIConnector
16+
from rai.communication.ros2 import ROS2Connector, ROS2Context, ROS2HRIConnector
1717
from rai.tools.ros2 import ROS2Toolkit
1818

1919

2020
@ROS2Context()
2121
def main():
2222
connector = ROS2HRIConnector(sources=["/from_human"], targets=["/to_human"])
23-
ari_connector = ROS2ARIConnector()
23+
ros2_connector = ROS2Connector()
2424
agent = ReActAgent(
2525
connectors={"hri": connector},
26-
tools=ROS2Toolkit(connector=ari_connector).get_tools(),
26+
tools=ROS2Toolkit(connector=ros2_connector).get_tools(),
2727
) # type: ignore
2828
runner = AgentRunner([agent])
2929
runner.run_and_wait_for_shutdown()

examples/agents/streamlit_react.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@
1818
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
1919
from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke
2020
from rai.agents.react_agent import ReActAgent
21-
from rai.communication.ros2 import ROS2ARIConnector
21+
from rai.communication.ros2 import ROS2Connector
2222
from rai.messages import HumanMultimodalMessage
2323
from rai.tools.ros2 import ROS2Toolkit
2424

2525

2626
@st.cache_resource
2727
def initialize_graph():
2828
rclpy.init()
29-
ari_connector = ROS2ARIConnector()
30-
tools = ROS2Toolkit(connector=ari_connector).get_tools()
29+
ros2_connector = ROS2Connector()
30+
tools = ROS2Toolkit(connector=ros2_connector).get_tools()
3131
agent = ReActAgent(connectors={}, tools=tools).agent
3232
return agent
3333

examples/agriculture-demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from langchain_core.runnables import Runnable
2020
from rai import get_llm_model
2121
from rai.agents.conversational_agent import State, create_conversational_agent
22-
from rai.communication.ros2.connectors import ROS2ARIConnector
22+
from rai.communication.ros2.connectors import ROS2Connector
2323
from rai.tools.ros2 import ROS2ServicesToolkit, ROS2TopicsToolkit
2424
from rai.tools.time import WaitForSecondsTool
2525
from rclpy.callback_groups import ReentrantCallbackGroup
@@ -103,7 +103,7 @@ def main():
103103
104104
Important: You must call only one service. The tractor can only handle one service call.
105105
"""
106-
connector = ROS2ARIConnector()
106+
connector = ROS2Connector()
107107
agent = create_conversational_agent(
108108
llm=get_llm_model("complex_model"),
109109
system_prompt=SYSTEM_PROMPT,

examples/manipulation-demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
from langchain_core.messages import HumanMessage
1919
from rai import get_llm_model
2020
from rai.agents.conversational_agent import create_conversational_agent
21-
from rai.communication.ros2.connectors import ROS2ARIConnector
21+
from rai.communication.ros2.connectors import ROS2Connector
2222
from rai.tools.ros2 import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool
2323
from rai.tools.ros2.manipulation import GetObjectPositionsTool, MoveToPointTool
2424
from rai_open_set_vision.tools import GetGrabbingPointTool
2525

2626

2727
def create_agent():
2828
rclpy.init()
29-
connector = ROS2ARIConnector()
29+
connector = ROS2Connector()
3030
node = connector.node
3131
node.declare_parameter("conversion_ratio", 1.0)
3232

examples/rosbot-xl-demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from langchain_core.tools import BaseTool
2020
from rai import get_llm_model
2121
from rai.agents import ReActAgent
22-
from rai.communication.ros2 import ROS2ARIConnector
22+
from rai.communication.ros2 import ROS2Connector
2323
from rai.frontend.streamlit import run_streamlit_app
2424
from rai.tools.ros2 import (
2525
GetObjectPositionsTool,
@@ -56,7 +56,7 @@ def initialize_agent():
5656
Your job is to transform user intent into meaningful, goal-driven behavior within the physical world.
5757
"""
5858

59-
connector = ROS2ARIConnector()
59+
connector = ROS2Connector()
6060
tools: List[BaseTool] = [
6161
GetROS2TransformConfiguredTool(
6262
connector=connector,

src/rai_asr/rai_asr/agents/asr_agent.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
from numpy.typing import NDArray
2424
from rai.agents.base import BaseAgent
2525
from rai.communication.ros2 import (
26-
ROS2ARIConnector,
27-
ROS2ARIMessage,
26+
ROS2Connector,
2827
ROS2HRIConnector,
2928
ROS2HRIMessage,
29+
ROS2Message,
3030
)
3131
from rai.communication.sound_device import (
3232
SoundDeviceConfig,
@@ -81,11 +81,11 @@ def __init__(
8181
targets=[], sources=[("microphone", microphone_config)]
8282
)
8383
ros2_hri_connector = ROS2HRIConnector(ros2_name, targets=["/from_human"])
84-
ros2_ari_connector = ROS2ARIConnector(ros2_name + "ari")
84+
ros2_connector = ROS2Connector(ros2_name + "ari")
8585
self.connectors = {
8686
"microphone": microphone,
8787
"ros2_hri": ros2_hri_connector,
88-
"ros2_ari": ros2_ari_connector,
88+
"ros2": ros2_connector,
8989
}
9090
super().__init__()
9191
self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
@@ -266,9 +266,9 @@ def _transcription_thread(self, identifier: str):
266266
def _send_ros2_message(self, data: str, topic: str):
267267
self.logger.debug(f"Sending message to {topic}: {data}")
268268
if topic == "/voice_commands":
269-
msg = ROS2ARIMessage({"data": data})
269+
msg = ROS2Message({"data": data})
270270
try:
271-
self.connectors["ros2_ari"].send_message(
271+
self.connectors["ros2"].send_message(
272272
msg, topic, msg_type="std_msgs/msg/String"
273273
)
274274
except Exception as e:

src/rai_bench/rai_bench/examples/o3de_test_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import rclpy
2323
from langchain.tools import BaseTool
2424
from rai.agents.conversational_agent import create_conversational_agent
25-
from rai.communication.ros2.connectors import ROS2ARIConnector
25+
from rai.communication.ros2.connectors import ROS2Connector
2626
from rai.initialization import get_llm_model
2727
from rai.tools.ros2 import (
2828
GetObjectPositionsTool,
@@ -46,7 +46,7 @@
4646

4747
if __name__ == "__main__":
4848
rclpy.init()
49-
connector = ROS2ARIConnector()
49+
connector = ROS2Connector()
5050
node = connector.node
5151
node.declare_parameter("conversion_ratio", 1.0)
5252

src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
import numpy as np
1919
import numpy.typing as npt
20-
from rai.communication.ros2.connectors import ROS2ARIConnector
21-
from rai.communication.ros2.messages import ROS2ARIMessage
20+
from rai.communication.ros2.connectors import ROS2Connector
21+
from rai.communication.ros2.messages import ROS2Message
2222
from rai.messages import MultimodalArtifact, preprocess_image
2323
from rai.tools.ros2 import (
2424
GetObjectPositionsTool,
@@ -31,7 +31,7 @@
3131

3232

3333
class MockGetROS2TopicsNamesAndTypesTool(GetROS2TopicsNamesAndTypesTool):
34-
connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
34+
connector: ROS2Connector = MagicMock(spec=ROS2Connector)
3535
mock_topics_names_and_types: list[str]
3636

3737
def _run(self) -> str:
@@ -46,7 +46,7 @@ def _run(self) -> str:
4646

4747

4848
class MockGetROS2ImageTool(GetROS2ImageTool):
49-
connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
49+
connector: ROS2Connector = MagicMock(spec=ROS2Connector)
5050
expected_topics: List[str]
5151

5252
def _run(
@@ -95,7 +95,7 @@ def generate_mock_image() -> npt.NDArray[np.uint8]:
9595

9696

9797
class MockReceiveROS2MessageTool(ReceiveROS2MessageTool):
98-
connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
98+
connector: ROS2Connector = MagicMock(spec=ROS2Connector)
9999
expected_topics: List[str]
100100

101101
def _run(self, topic: str) -> str:
@@ -120,14 +120,14 @@ def _run(self, topic: str) -> str:
120120
raise ValueError(
121121
f"Topic {topic} is not available within 1.0 seconds. Check if the topic exists."
122122
)
123-
message: ROS2ARIMessage = MagicMock(spec=ROS2ARIMessage)
123+
message: ROS2Message = MagicMock(spec=ROS2Message)
124124
message.payload = {"mock": "payload"}
125125
message.metadata = {"mock": "metadata"}
126126
return str({"payload": message.payload, "metadata": message.metadata})
127127

128128

129129
class MockMoveToPointTool(MoveToPointTool):
130-
connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
130+
connector: ROS2Connector = MagicMock(spec=ROS2Connector)
131131

132132
def _run(self, x: float, y: float, z: float, task: str) -> str:
133133
"""Method that return a mock message with the end effector position.
@@ -149,7 +149,7 @@ def _run(self, x: float, y: float, z: float, task: str) -> str:
149149

150150

151151
class MockGetObjectPositionsTool(GetObjectPositionsTool):
152-
connector: ROS2ARIConnector = MagicMock(spec=ROS2ARIConnector)
152+
connector: ROS2Connector = MagicMock(spec=ROS2Connector)
153153

154154
# Create mock instances for the arguments
155155
target_frame: str = MagicMock(spec=str)

src/rai_core/rai/communication/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .ari_connector import ARIConnector, ARIMessage
1615
from .base_connector import BaseConnector, BaseMessage
1716
from .hri_connector import HRIConnector, HRIMessage
1817

1918
__all__ = [
20-
"ARIConnector",
21-
"ARIMessage",
2219
"BaseConnector",
2320
"BaseMessage",
2421
"HRIConnector",

0 commit comments

Comments
 (0)