Skip to content

Commit c7cda5e

Browse files
committed
Pretty tree print for frames
1 parent 0d67ac2 commit c7cda5e

File tree

3 files changed

+50
-15
lines changed

3 files changed

+50
-15
lines changed

src/pytorch_kinematics/chain.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,12 @@ def get_link_names(self):
280280
def get_frame_indices(self, *frame_names):
281281
return torch.tensor([self.frame_to_idx[n] for n in frame_names], dtype=torch.long, device=self.device)
282282

283+
def print_link_tree(self, do_print=True):
284+
tree = str(self._root)
285+
if do_print:
286+
print(tree)
287+
return tree
288+
283289
def forward_kinematics(self, th, frame_indices: Optional = None):
284290
"""
285291
Compute forward kinematics for the given joint values.

src/pytorch_kinematics/frame.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ def __repr__(self):
8585
self.axis)
8686

8787

88+
# prefix components:
89+
space = ' '
90+
branch = '│ '
91+
# pointers:
92+
tee = '├── '
93+
last = '└── '
94+
8895
class Frame(object):
8996
def __init__(self, name=None, link=None, joint=None, children=None):
9097
self.name = 'None' if name is None else name
@@ -93,10 +100,18 @@ def __init__(self, name=None, link=None, joint=None, children=None):
93100
if children is None:
94101
self.children = []
95102

96-
def __str__(self, level=0):
97-
ret = " \t" * level + self.name + "\n"
98-
for child in self.children:
99-
ret += child.__str__(level + 1)
103+
def __str__(self, prefix='', root=True):
104+
pointers = [tee] * (len(self.children) - 1) + [last]
105+
if root:
106+
ret = prefix + self.name + "\n"
107+
else:
108+
ret = ""
109+
for pointer, child in zip(pointers, self.children):
110+
ret += prefix + pointer + child.name + "\n"
111+
if child.children:
112+
extension = branch if pointer == tee else space
113+
# i.e. space because last, └── , above so no more |
114+
ret += child.__str__(prefix=prefix + extension, root=False)
100115
return ret
101116

102117
def to(self, *args, **kwargs):

tests/test_inverse_kinematics.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -213,20 +213,34 @@ def test_extract_serial_chain_from_tree():
213213
# device = "cpu"
214214
urdf = "widowx/wx250s.urdf"
215215
full_urdf = os.path.join(TEST_DIR, urdf)
216-
chain = pk.build_serial_chain_from_urdf(open(full_urdf, mode="rb").read(), "ee_gripper_link")
217-
# chain = pk.SerialChain(chain, "ee_gripper_link", "base_link")
216+
chain = pk.build_chain_from_urdf(open(full_urdf, mode="rb").read())
217+
# full frames
218+
full_frame_expected = """
219+
base_link
220+
└── shoulder_link
221+
└── upper_arm_link
222+
└── upper_forearm_link
223+
└── lower_forearm_link
224+
└── wrist_link
225+
└── gripper_link
226+
└── ee_arm_link
227+
├── gripper_prop_link
228+
└── gripper_bar_link
229+
└── fingers_link
230+
├── left_finger_link
231+
├── right_finger_link
232+
└── ee_gripper_link
233+
"""
234+
full_frame = chain.print_link_tree()
235+
assert full_frame_expected.strip() == full_frame.strip()
236+
237+
chain = pk.SerialChain(chain, "ee_gripper_link", "base_link")
238+
serial_frame = chain.print_link_tree()
218239
chain = chain.to(device=device)
219240

220241
# full chain should have DOF = 8, however since we are creating just a serial chain to ee_gripper_link, should be 6
221-
# TODO pretty tree print
222-
"""
223-
/
224-
└── gripper_bar_link
225-
└── fingers_link
226-
├── left_finger_link
227-
├── right_finger_link
228-
└── ee_gripper_link
229-
"""
242+
dof = len(chain.get_joints(exclude_fixed=True))
243+
assert dof == 6
230244

231245
# robot frame
232246
pos = torch.tensor([0.0, 0.0, 0.0], device=device)

0 commit comments

Comments
 (0)