Skip to content

Commit 086db12

Browse files
authored
feat: unify pydantic models for ros2 types (#545)
1 parent b62fee1 commit 086db12

32 files changed

+825
-582
lines changed

src/rai_bench/rai_bench/manipulation_o3de/benchmark.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,7 @@
2222
from rclpy.impl.rcutils_logger import RcutilsLogger
2323

2424
from rai_bench.manipulation_o3de.interfaces import Task
25-
from rai_sim.simulation_bridge import (
26-
Entity,
27-
SimulationBridge,
28-
SimulationConfigT,
29-
)
25+
from rai_sim.simulation_bridge import Entity, SimulationBridge, SimulationConfigT
3026

3127
loggers_type = Union[RcutilsLogger, logging.Logger]
3228
EntityT = TypeVar("EntityT", bound=Entity)

src/rai_bench/rai_bench/manipulation_o3de/interfaces.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
from collections import defaultdict
1818
from typing import Dict, List, Set, Tuple, TypeVar, Union
1919

20+
from rai.types import Pose
2021
from rclpy.impl.rcutils_logger import RcutilsLogger
2122

2223
from rai_sim.simulation_bridge import (
2324
Entity,
24-
Pose,
2525
SimulationBridge,
2626
SimulationConfig,
2727
SimulationConfigT,
28+
SpawnedEntity,
2829
)
2930

3031
loggers_type = Union[RcutilsLogger, logging.Logger]
@@ -114,9 +115,9 @@ def filter_entities_by_object_type(
114115
def euclidean_distance(self, pos1: Pose, pos2: Pose) -> float:
115116
"""Calculate euclidean distance between 2 positions"""
116117
return (
117-
(pos1.translation.x - pos2.translation.x) ** 2
118-
+ (pos1.translation.y - pos2.translation.y) ** 2
119-
+ (pos1.translation.z - pos2.translation.z) ** 2
118+
(pos1.position.x - pos2.position.x) ** 2
119+
+ (pos1.position.y - pos2.position.y) ** 2
120+
+ (pos1.position.z - pos2.position.z) ** 2
120121
) ** 0.5
121122

122123
def is_adjacent(self, pos1: Pose, pos2: Pose, threshold_distance: float):
@@ -217,7 +218,9 @@ def build_neighbourhood_list(
217218
other
218219
for other in entities
219220
if entity != other
220-
and self.is_adjacent(entity.pose, other.pose, threshold_distance)
221+
and self.is_adjacent(
222+
entity.pose.pose, other.pose.pose, threshold_distance
223+
)
221224
]
222225
return neighbourhood_graph
223226

@@ -335,15 +338,16 @@ def group_entities_along_z_axis(
335338
"""
336339

337340
entities = sorted(
338-
entities, key=lambda ent: (ent.pose.translation.x, ent.pose.translation.y)
341+
entities,
342+
key=lambda ent: (ent.pose.pose.position.x, ent.pose.pose.position.y),
339343
)
340344

341345
groups: List[List[EntityT]] = []
342346
for entity in entities:
343347
placed = False
344348
for group in groups:
345-
dx = group[0].pose.translation.x - entity.pose.translation.x
346-
dy = group[0].pose.translation.y - entity.pose.translation.y
349+
dx = group[0].pose.pose.position.x - entity.pose.pose.position.x
350+
dy = group[0].pose.pose.position.y - entity.pose.pose.position.y
347351
if math.sqrt(dx * dx + dy * dy) <= margin:
348352
group.append(entity)
349353
placed = True
@@ -424,12 +428,14 @@ def validate_config(self, simulation_config: SimulationConfig) -> bool:
424428
return False
425429

426430
@abstractmethod
427-
def calculate_correct(self, entities: List[EntityT]) -> Tuple[int, int]:
431+
def calculate_correct(
432+
self, entities: List[Entity] | List[SpawnedEntity]
433+
) -> Tuple[int, int]:
428434
"""Method to calculate how many objects are placed correctly
429435
430436
Parameters
431437
----------
432-
entities : List[EntityT]
438+
entities : List[Entity]
433439
list of ALL entities present in the simulaiton scene
434440
435441
Returns

src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]:
7373
entities=entities, object_types=self.obj_types
7474
)
7575
correct = sum(
76-
1 for ent in selected_type_objects if ent.pose.translation.y > 0.0
76+
1 for ent in selected_type_objects if ent.pose.pose.position.y > 0.0
7777
)
7878
incorrect: int = len(selected_type_objects) - correct
7979
return correct, incorrect

src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]:
9090
correct = 0
9191

9292
for ent in target_objects:
93-
dx = ent.pose.translation.x - self.target_position[0]
94-
dy = ent.pose.translation.y - self.target_position[1]
93+
dx = ent.pose.pose.position.x - self.target_position[0]
94+
dy = ent.pose.pose.position.y - self.target_position[1]
9595
distance = math.sqrt(dx**2 + dy**2)
9696
if distance <= self.allowable_displacement:
9797
correct = 1 # Only one correct placement is needed.

src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]:
8989
1
9090
for ent in entities
9191
if self.is_adjacent_to_any(
92-
ent.pose,
93-
[e.pose for e in entities if e != ent],
92+
ent.pose.pose,
93+
[e.pose.pose for e in entities if e != ent],
9494
self.threshold_distance,
9595
)
9696
)

src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
import math
1717
from typing import List, Tuple, Union
1818

19+
from rai.types import Quaternion
1920
from rclpy.impl.rcutils_logger import RcutilsLogger
2021

2122
from rai_bench.manipulation_o3de.interfaces import (
2223
ManipulationTask,
2324
)
24-
from rai_sim.simulation_bridge import Entity, Rotation, SimulationConfig
25+
from rai_sim.simulation_bridge import Entity, SimulationConfig
2526

2627
loggers_type = Union[RcutilsLogger, logging.Logger]
2728

@@ -30,7 +31,7 @@ class RotateObjectTask(ManipulationTask):
3031
def __init__(
3132
self,
3233
obj_types: List[str],
33-
target_quaternion: Rotation,
34+
target_quaternion: Quaternion,
3435
logger: loggers_type | None = None,
3536
):
3637
# NOTE (jmatejcz) for now manipulaiton tool does not support passing rotation
@@ -103,14 +104,14 @@ def calculate_correct(
103104
incorrect = 0
104105
for entity in entities:
105106
if entity.prefab_name in self.obj_types:
106-
if not entity.pose.rotation:
107+
if not entity.pose.pose.orientation:
107108
ValueError("Entity has no rotation defined.")
108109
else:
109110
dot = (
110-
entity.pose.rotation.x * self.target_quaternion.x
111-
+ entity.pose.rotation.y * self.target_quaternion.y
112-
+ entity.pose.rotation.z * self.target_quaternion.z
113-
+ entity.pose.rotation.w * self.target_quaternion.w
111+
entity.pose.pose.orientation.x * self.target_quaternion.x
112+
+ entity.pose.pose.orientation.y * self.target_quaternion.y
113+
+ entity.pose.pose.orientation.z * self.target_quaternion.z
114+
+ entity.pose.pose.orientation.w * self.target_quaternion.w
114115
)
115116
# Account for the double cover: q and -q represent the same rotation.
116117
dot = abs(dot)

src/rai_bench/rai_bench/tool_calling_agent/messages/actions.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,66 +14,66 @@
1414

1515
from typing import Any, Dict, Optional
1616

17-
from rai_bench.tool_calling_agent.messages.base import PoseStamped, Ros2BaseModel, Time
17+
from rai.types import PoseStamped, ROS2BaseModel, Time
1818

1919

20-
class TaskGoal(Ros2BaseModel):
20+
class TaskGoal(ROS2BaseModel):
2121
task: Optional[str] = ""
2222
description: Optional[str] = ""
2323
priority: Optional[str] = ""
2424

2525

26-
class TaskResult(Ros2BaseModel):
26+
class TaskResult(ROS2BaseModel):
2727
success: Optional[bool] = False
2828
report: Optional[str] = ""
2929

3030

31-
class TaskFeedback(Ros2BaseModel):
31+
class TaskFeedback(ROS2BaseModel):
3232
current_status: Optional[str] = ""
3333

3434

35-
class LoadMapRequest(Ros2BaseModel):
35+
class LoadMapRequest(ROS2BaseModel):
3636
filename: Optional[str] = ""
3737

3838

39-
class LoadMapResponse(Ros2BaseModel):
39+
class LoadMapResponse(ROS2BaseModel):
4040
success: Optional[bool] = False
4141

4242

43-
class NavigateToPoseGoal(Ros2BaseModel):
43+
class NavigateToPoseGoal(ROS2BaseModel):
4444
pose: Optional[PoseStamped] = None
4545
behavior_tree: Optional[str] = None
4646

4747

48-
class ActionResult(Ros2BaseModel):
48+
class ActionResult(ROS2BaseModel):
4949
result: Optional[Dict[str, Any]] = None
5050

5151

52-
class NavigateToPoseFeedback(Ros2BaseModel):
52+
class NavigateToPoseFeedback(ROS2BaseModel):
5353
current_pose: Optional[PoseStamped] = None
5454
navigation_time: Optional[Time] = None
5555
estimated_time_remaining: Optional[Time] = None
5656
number_of_recoveries: Optional[int] = None
5757
distance_remaining: Optional[float] = None
5858

5959

60-
class NavigateToPoseAction(Ros2BaseModel):
60+
class NavigateToPoseAction(ROS2BaseModel):
6161
goal: Optional[NavigateToPoseGoal] = None
6262
result: Optional[ActionResult] = None
6363
feedback: Optional[NavigateToPoseFeedback] = None
6464

6565

66-
class SpinGoal(Ros2BaseModel):
66+
class SpinGoal(ROS2BaseModel):
6767
target_yaw: Optional[float] = None
6868
time_allowance: Optional[Time] = None
6969

7070

71-
class SpinFeedback(Ros2BaseModel):
71+
class SpinFeedback(ROS2BaseModel):
7272
angle_turned: Optional[float] = None
7373
remaining_yaw: Optional[float] = None
7474

7575

76-
class SpinAction(Ros2BaseModel):
76+
class SpinAction(ROS2BaseModel):
7777
goal: Optional[SpinGoal] = None
7878
result: Optional[ActionResult] = None
7979
feedback: Optional[SpinFeedback] = None

src/rai_bench/rai_bench/tool_calling_agent/messages/base.py

Lines changed: 4 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -12,93 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List, Optional
15+
from rai.types import ROS2BaseModel, Time
1616

17-
from pydantic import BaseModel, ConfigDict
1817

19-
20-
class Ros2BaseModel(BaseModel):
21-
model_config = ConfigDict(extra="forbid")
22-
23-
24-
class Time(Ros2BaseModel):
25-
sec: Optional[int] = 0
26-
nanosec: Optional[int] = 0
27-
28-
29-
class Header(Ros2BaseModel):
30-
stamp: Optional[Time] = Time()
31-
frame_id: Optional[str] = ""
32-
33-
34-
class RegionOfInterest(Ros2BaseModel):
35-
x_offset: Optional[int] = 0
36-
y_offset: Optional[int] = 0
37-
height: Optional[int] = 0
38-
width: Optional[int] = 0
39-
do_rectify: Optional[bool] = False
40-
41-
42-
class Position(Ros2BaseModel):
43-
x: Optional[float] = 0.0
44-
y: Optional[float] = 0.0
45-
z: Optional[float] = 0.0
46-
47-
48-
class Orientation(Ros2BaseModel):
49-
x: Optional[float] = 0.0
50-
y: Optional[float] = 0.0
51-
z: Optional[float] = 0.0
52-
w: Optional[float] = 1.0
53-
54-
55-
class Pose(Ros2BaseModel):
56-
position: Optional[Position] = Position()
57-
orientation: Optional[Orientation] = Orientation()
58-
59-
60-
class PoseStamped(Ros2BaseModel):
61-
header: Optional[Header] = Header()
62-
pose: Optional[Pose] = Pose()
63-
64-
65-
class Clock(Ros2BaseModel):
66-
clock: Optional[Time] = Time()
67-
68-
69-
class ObjectHypothesis(Ros2BaseModel):
70-
class_id: Optional[str] = ""
71-
score: Optional[float] = 0.0
72-
73-
74-
class PoseWithCovariance(Ros2BaseModel):
75-
pose: Optional[Pose] = Pose()
76-
covariance: Optional[List[float]] = [0.0] * 36
77-
78-
79-
class ObjectHypothesisWithPose(Ros2BaseModel):
80-
hypothesis: Optional[ObjectHypothesis] = ObjectHypothesis()
81-
pose: Optional[PoseWithCovariance] = PoseWithCovariance()
82-
83-
84-
class Point2D(Ros2BaseModel):
85-
x: Optional[float] = 0.0
86-
y: Optional[float] = 0.0
87-
88-
89-
class Pose2D(Ros2BaseModel):
90-
position: Optional[Point2D] = Point2D()
91-
theta: Optional[float] = 0.0
92-
93-
94-
class BoundingBox2D(Ros2BaseModel):
95-
center: Optional[Pose2D] = Pose2D()
96-
size_x: Optional[float] = 0.0
97-
size_y: Optional[float] = 0.0
98-
99-
100-
class Detection2D(Ros2BaseModel):
101-
header: Optional[Header] = Header()
102-
results: Optional[List[ObjectHypothesisWithPose]] = []
103-
bbox: Optional[BoundingBox2D] = BoundingBox2D()
104-
id: Optional[str] = ""
18+
class Clock(ROS2BaseModel):
19+
_prefix: str = "rosgraph_msgs/msg"
20+
clock: Time = Time()

0 commit comments

Comments
 (0)