Skip to content

Commit 666fd06

Browse files
committed
Update test_inverse_kinematics.py
1 parent 1582d67 commit 666fd06

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

tests/test_inverse_kinematics.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,27 @@ def make_transparent(link):
2424
for link in visual_data:
2525
make_transparent(link)
2626

27-
28-
def test_jacobian_follower():
27+
def create_test_chain(robot="kuka_iiwa", device="cpu"):
28+
if robot == "kuka_iiwa":
29+
urdf = "kuka_iiwa/model.urdf"
30+
search_path = pybullet_data.getDataPath()
31+
full_urdf = os.path.join(search_path, urdf)
32+
chain = pk.build_serial_chain_from_urdf(open(full_urdf).read(), "lbr_iiwa_link_7")
33+
chain = chain.to(device=device)
34+
elif robot == "widowx":
35+
urdf = "widowx/wx250s.urdf"
36+
full_urdf = urdf
37+
chain = pk.build_serial_chain_from_urdf(open(full_urdf, "rb").read(), "ee_gripper_link")
38+
chain = chain.to(device=device)
39+
else:
40+
raise NotImplementedError(f"Robot {robot} not implemented")
41+
return chain, urdf
42+
43+
def test_jacobian_follower(robot="kuka_iiwa"):
2944
pytorch_seed.seed(2)
3045
device = "cuda" if torch.cuda.is_available() else "cpu"
31-
# device = "cpu"
32-
urdf = "kuka_iiwa/model.urdf"
3346
search_path = pybullet_data.getDataPath()
34-
full_urdf = os.path.join(search_path, urdf)
35-
chain = pk.build_serial_chain_from_urdf(open(full_urdf).read(), "lbr_iiwa_link_7")
36-
chain = chain.to(device=device)
47+
chain, urdf = create_test_chain(robot=robot, device=device)
3748

3849
# robot frame
3950
pos = torch.tensor([0.0, 0.0, 0.0], device=device)
@@ -45,7 +56,7 @@ def test_jacobian_follower():
4556
# generate random goal joint angles (so these are all achievable)
4657
# use the joint limits to generate random joint angles
4758
lim = torch.tensor(chain.get_joint_limits(), device=device)
48-
goal_q = torch.rand(M, 7, device=device) * (lim[1] - lim[0]) + lim[0]
59+
goal_q = torch.rand(M, lim.shape[1], device=device) * (lim[1] - lim[0]) + lim[0]
4960

5061
# get ee pose (in robot frame)
5162
goal_in_rob_frame_tf = chain.forward_kinematics(goal_q)
@@ -209,5 +220,6 @@ def test_ik_in_place_no_err():
209220

210221

211222
if __name__ == "__main__":
212-
test_jacobian_follower()
223+
# test_jacobian_follower(robot="kuka_iiwa")
224+
test_jacobian_follower(robot="widowx")
213225
test_ik_in_place_no_err()

0 commit comments

Comments
 (0)