Skip to content

Commit c4ba634

Browse files
refactor: remove ROS 2 imports from non ROS 2 files in rai_core (#528)
Co-authored-by: Bartłomiej Boczek <[email protected]>
1 parent dd34d5e commit c4ba634

File tree

11 files changed

+41
-66
lines changed

11 files changed

+41
-66
lines changed

src/rai_core/rai/agents/conversational_agent.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,23 @@
1515

1616
import logging
1717
from functools import partial
18-
from typing import List, Optional, TypedDict, Union
18+
from typing import List, Optional, TypedDict
1919

2020
from langchain.chat_models.base import BaseChatModel
2121
from langchain_core.messages import BaseMessage, SystemMessage
2222
from langchain_core.tools import BaseTool
2323
from langgraph.graph import START, StateGraph
2424
from langgraph.graph.state import CompiledStateGraph
2525
from langgraph.prebuilt.tool_node import tools_condition
26-
from rclpy.impl.rcutils_logger import RcutilsLogger
2726

2827
from rai.agents.tool_runner import ToolRunner
2928

30-
loggers_type = Union[RcutilsLogger, logging.Logger]
31-
3229

3330
class State(TypedDict):
3431
messages: List[BaseMessage]
3532

3633

37-
def agent(llm: BaseChatModel, logger: loggers_type, system_prompt: str, state: State):
34+
def agent(llm: BaseChatModel, logger: logging.Logger, system_prompt: str, state: State):
3835
logger.info("Running thinker")
3936

4037
# If there are no messages, do nothing
@@ -53,7 +50,7 @@ def create_conversational_agent(
5350
llm: BaseChatModel,
5451
tools: List[BaseTool],
5552
system_prompt: str,
56-
logger: Optional[RcutilsLogger | logging.Logger] = None,
53+
logger: Optional[logging.Logger] = None,
5754
debug=False,
5855
) -> CompiledStateGraph:
5956
_logger = None

src/rai_core/rai/agents/state_based.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,10 @@
2727
from langgraph.graph.graph import CompiledGraph
2828
from langgraph.prebuilt.tool_node import msg_content_output
2929
from pydantic import BaseModel, Field, ValidationError
30-
from rclpy.impl.rcutils_logger import RcutilsLogger
3130

3231
from rai.agents.tool_runner import ToolRunner
3332
from rai.messages import HumanMultimodalMessage
3433

35-
loggers_type = Union[RcutilsLogger, logging.Logger]
36-
3734

3835
class State(TypedDict):
3936
messages: List[BaseMessage]
@@ -68,7 +65,7 @@ def tools_condition(
6865
return "reporter"
6966

7067

71-
def thinker(llm: BaseChatModel, logger: loggers_type, state: State):
68+
def thinker(llm: BaseChatModel, logger: logging.Logger, state: State):
7269
logger.info("Running thinker")
7370
prompt = (
7471
"Based on the data provided, reason about the situation. "
@@ -81,7 +78,7 @@ def thinker(llm: BaseChatModel, logger: loggers_type, state: State):
8178

8279

8380
def decider(
84-
llm: Runnable[LanguageModelInput, BaseMessage], logger: loggers_type, state: State
81+
llm: Runnable[LanguageModelInput, BaseMessage], logger: logging.Logger, state: State
8582
):
8683
logger.info("Running decider")
8784
prompt = (
@@ -98,7 +95,7 @@ def decider(
9895
return state
9996

10097

101-
def reporter(llm: BaseChatModel, logger: loggers_type, state: State):
98+
def reporter(llm: BaseChatModel, logger: logging.Logger, state: State):
10299
logger.info("Summarizing the conversation")
103100
prompt = (
104101
"You are the reporter. Your task is to summarize what happened previously. "
@@ -126,7 +123,7 @@ def reporter(llm: BaseChatModel, logger: loggers_type, state: State):
126123

127124

128125
def retriever_wrapper(
129-
state_retriever: Callable[[], Dict[str, Any]], logger: loggers_type, state: State
126+
state_retriever: Callable[[], Dict[str, Any]], logger: logging.Logger, state: State
130127
):
131128
"""This wrapper is used to retrieve multimodal information from the output of state_retriever."""
132129
ts = time.perf_counter()
@@ -150,10 +147,10 @@ def create_state_based_agent(
150147
llm: BaseChatModel,
151148
tools: List[BaseTool],
152149
state_retriever: Callable[[], Dict[str, Any]],
153-
logger: Optional[RcutilsLogger | logging.Logger] = None,
150+
logger: Optional[logging.Logger] = None,
154151
) -> CompiledGraph:
155152
_logger = None
156-
if isinstance(logger, RcutilsLogger):
153+
if isinstance(logger, logging.Logger):
157154
_logger = logger
158155
else:
159156
_logger = logging.getLogger(__name__)

src/rai_core/rai/agents/tool_runner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from langgraph.prebuilt.tool_node import msg_content_output
2727
from langgraph.utils.runnable import RunnableCallable
2828
from pydantic import ValidationError
29-
from rclpy.impl.rcutils_logger import RcutilsLogger
3029

3130
from rai.messages import MultimodalArtifact, ToolMultimodalMessage, store_artifacts
3231

@@ -38,7 +37,7 @@ def __init__(
3837
*,
3938
name: str = "tools",
4039
tags: Optional[list[str]] = None,
41-
logger: Optional[Union[RcutilsLogger, logging.Logger]] = None,
40+
logger: Optional[logging.Logger] = None,
4241
) -> None:
4342
super().__init__(self._func, name=name, tags=tags, trace=False)
4443
self.logger = logger or logging.getLogger(__name__)

src/rai_core/rai/communication/ros2/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import importlib.util
16+
17+
if importlib.util.find_spec("rclpy") is None:
18+
raise ImportError(
19+
"This is a ROS2 feature. Make sure ROS2 is installed and sourced."
20+
)
21+
1522
from .api import (
1623
IROS2Message, # TODO: IROS2Message should not be a part of the public API
1724
TopicConfig, # TODO: TopicConfig should not be a part of the public API

src/rai_core/rai/communication/sound_device/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import importlib.util
16+
17+
if importlib.util.find_spec("sounddevice") is None:
18+
raise ImportError(
19+
"This feature is based on sounddevice. Make sure sounddevice is installed."
20+
)
21+
1522
from .api import SoundDeviceAPI, SoundDeviceConfig, SoundDeviceError
1623
from .connector import SoundDeviceConnector, SoundDeviceMessage
1724

src/rai_core/rai/communication/sound_device/api.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any, Callable, Optional
1818

1919
import numpy as np
20+
import sounddevice as sd
2021
from numpy._typing import NDArray
2122
from pydub import AudioSegment
2223
from scipy.signal import resample
@@ -85,10 +86,6 @@ def __init__(self, config: SoundDeviceConfig):
8586
self.device_name = ""
8687

8788
self.config = config
88-
try:
89-
import sounddevice as sd
90-
except ImportError:
91-
raise SoundDeviceError("SoundDeviceAPI requires sound_device module!")
9289
sd.default.latency = ("low", "low") # type: ignore
9390
if config.device_name:
9491
self.device_name = config.device_name
@@ -143,10 +140,7 @@ def write(self, data: AudioSegment, blocking: bool = False, loop: bool = False):
143140
"""
144141
if not self.write_flag:
145142
raise SoundDeviceError(f"{self.device_name} does not support writing!")
146-
try:
147-
import sounddevice as sd
148-
except ImportError:
149-
raise SoundDeviceError("SoundDeviceAPI requires sound_device module!")
143+
150144
audio = np.array(data.get_array_of_samples())
151145
sd.play(
152146
audio,
@@ -187,10 +181,7 @@ def read(self, time: float, blocking: bool = False) -> AudioSegment:
187181

188182
if not self.read_flag:
189183
raise SoundDeviceError(f"{self.device_name} does not support reading!")
190-
try:
191-
import sounddevice as sd
192-
except ImportError:
193-
raise SoundDeviceError("SoundDeviceAPI requires sound_device module!")
184+
194185
frames = int(time * self.sample_rate)
195186
recording = sd.rec(
196187
frames=frames,
@@ -217,10 +208,6 @@ def stop(self):
217208
- This is a convenience function to stop the sound device from playing or recording.
218209
- It will stop any sound that is currently playing and any recording currently happening.
219210
"""
220-
try:
221-
import sounddevice as sd
222-
except ImportError:
223-
raise SoundDeviceError("SoundDeviceAPI requires sound_device module!")
224211
sd.stop()
225212

226213
def wait(self):
@@ -232,10 +219,6 @@ def wait(self):
232219
- This is a convenience function to wait for the sound device to finish playing or recording.
233220
- It will block until the sound is played or recorded.
234221
"""
235-
try:
236-
import sounddevice as sd
237-
except ImportError:
238-
raise SoundDeviceError("SoundDeviceAPI requires sound_device module!")
239222
sd.wait()
240223

241224
def open_write_stream(
@@ -250,10 +233,6 @@ def open_write_stream(
250233
f"{self.device_name} does not support streaming writing!"
251234
)
252235

253-
try:
254-
import sounddevice as sd
255-
except ImportError:
256-
raise SoundDeviceError("SoundDeviceAPI requires sound_device module!")
257236
from sounddevice import CallbackFlags
258237

259238
def callback(indata: NDArray, frames: int, time: Any, status: CallbackFlags):
@@ -323,10 +302,6 @@ def callback(indata: NDArray, frames: int, _, status: CallbackFlags):
323302
}
324303
on_feedback(indata, flag_dict)
325304

326-
try:
327-
import sounddevice as sd
328-
except ImportError:
329-
raise SoundDeviceError("SoundDeviceAPI requires sound_device module!")
330305
try:
331306
if self.config.consumer_sampling_rate is None:
332307
window_size_samples = self.config.block_size * self.sample_rate

src/rai_core/rai/tools/ros2/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import importlib.util
16+
17+
if importlib.util.find_spec("rclpy") is None:
18+
raise ImportError(
19+
"This is a ROS2 feature. Make sure ROS2 is installed and sourced."
20+
)
21+
1522
from .cli import (
1623
ROS2CLIToolkit,
1724
ros2_action,

src/rai_core/rai/tools/ros2/generic/actions.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import importlib.util
16-
17-
if importlib.util.find_spec("rclpy") is None:
18-
raise ImportError(
19-
"This is a ROS2 feature. Make sure ROS2 is installed and sourced."
20-
)
21-
2215
import uuid
2316
from collections import defaultdict
2417
from functools import partial

src/rai_core/rai/tools/ros2/generic/services.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import importlib.util
16-
17-
if importlib.util.find_spec("rclpy") is None:
18-
raise ImportError(
19-
"This is a ROS2 feature. Make sure ROS2 is installed and sourced."
20-
)
21-
2215
from typing import Any, Dict, List, Type
2316

2417
from langchain_core.tools import BaseTool

src/rai_core/rai/tools/ros2/generic/topics.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import importlib.util
16-
17-
if importlib.util.find_spec("rclpy") is None:
18-
raise ImportError(
19-
"This is a ROS2 feature. Make sure ROS2 is installed and sourced."
20-
)
21-
2215
import json
2316
from typing import Any, Dict, List, Literal, Tuple, Type
2417

0 commit comments

Comments
 (0)