Skip to content

Commit 53ff8a9

Browse files
committed
Franka python example
1 parent d035b87 commit 53ff8a9

File tree

1 file changed

+148
-0
lines changed

1 file changed

+148
-0
lines changed

examples/franka.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import numpy as np
2+
from simple_mpc import RobotModelHandler, RobotDataHandler, ArmDynamicsOCP, ArmMPC
3+
import pinocchio as pin
4+
import coal
5+
import example_robot_data as erd
6+
from pinocchio.visualize import MeshcatVisualizer
7+
8+
import simple
9+
10+
# ####### CONFIGURATION ############
11+
# Load robot
12+
robotComplete = erd.load("panda")
13+
rmodelComplete: pin.Model = robotComplete.model
14+
15+
# Reduce model
16+
qComplete = rmodelComplete.referenceConfigurations["default"]
17+
locked_joints = [8, 9]
18+
19+
robot = robotComplete.buildReducedRobot(locked_joints, qComplete)
20+
rmodel: pin.Model = robot.model
21+
rdata = rmodel.createData()
22+
23+
geom_model = robot.collision_model
24+
q0 = rmodel.referenceConfigurations["default"]
25+
tool_name = "panda_leftfinger"
26+
tool_id = rmodel.getFrameId(tool_name)
27+
visual_model = robot.visual_model
28+
29+
# Create the simulator object
30+
simulator = simple.Simulator(rmodel, geom_model)
31+
32+
# Create Model and Data handler
33+
model_handler = RobotModelHandler(rmodel, "default", "panda_link0")
34+
data_handler = RobotDataHandler(model_handler)
35+
36+
# Create OCP
37+
w_q = np.ones(7) * 1
38+
w_v = np.ones(7) * 1
39+
w_x = np.concatenate((w_q, w_v))
40+
w_u = np.ones(7) * 1e-2
41+
w_frame = np.diag(np.array([1000, 1000, 1000]))
42+
43+
dt = 0.01
44+
dt_sim = 1e-3
45+
problem_conf = dict(
46+
timestep=dt,
47+
w_x=np.diag(w_x),
48+
w_u=np.diag(w_u),
49+
gravity=np.array([0, 0, -9.81]),
50+
w_frame=w_frame,
51+
umin=-model_handler.getModel().effortLimit,
52+
umax=model_handler.getModel().effortLimit,
53+
qmin=model_handler.getModel().lowerPositionLimit,
54+
qmax=model_handler.getModel().upperPositionLimit,
55+
torque_limits=True,
56+
kinematics_limits=True,
57+
ee_name=tool_name,
58+
)
59+
T = 100
60+
61+
dynproblem = ArmDynamicsOCP(problem_conf, model_handler)
62+
dynproblem.createProblem(model_handler.getReferenceState(), T)
63+
64+
# Create MPC
65+
N_simu = int(dt / dt_sim)
66+
mpc_conf = dict(
67+
TOL=1e-4,
68+
mu_init=1e-8,
69+
max_iters=1,
70+
num_threads=1,
71+
timestep=dt,
72+
)
73+
74+
mpc = ArmMPC(mpc_conf, dynproblem)
75+
76+
target_pos = np.array([0.15, 0.5, 0.5])
77+
78+
nv = mpc.getModelHandler().getModel().nv
79+
nx = nv * 2
80+
81+
q = q0.copy()
82+
v = np.zeros(nv)
83+
x_measured = np.concatenate([q, v])
84+
mpc.getDataHandler().updateInternalData(x_measured, False)
85+
86+
# Visualization of target
87+
fr_name = "universe"
88+
fr_id = rmodel.getFrameId(fr_name)
89+
joint_id = rmodel.frames[fr_id].parentJoint
90+
target_place = pin.SE3.Identity()
91+
target_place.translation = target_pos
92+
target_object1 = pin.GeometryObject(
93+
"target1", fr_id, joint_id, coal.Sphere(0.02), target_place
94+
)
95+
target_object1.meshColor[:] = [0.5, 0.5, 1.0, 1.0]
96+
visual_model.addGeometryObject(target_object1)
97+
target_id1 = visual_model.getGeometryId("target1")
98+
visual_data = visual_model.createData()
99+
100+
z_mov = 0.2
101+
x_mov = 0.2
102+
freq_mov = 1
103+
104+
### Load visualizer
105+
vizer = MeshcatVisualizer(rmodel, geom_model, visual_model, data=rdata)
106+
vizer.initViewer(open=True, loadModel=True)
107+
vizer.display(pin.neutral(rmodel))
108+
vizer.setBackgroundColor()
109+
110+
vizer.display(q)
111+
112+
Tmpc = 1000
113+
target_new = target_pos.copy()
114+
115+
print("Start simu")
116+
for t in range(1000):
117+
if t == 300:
118+
# Stop tracking target
119+
print("Switch to rest")
120+
mpc.switchToRest()
121+
if t > 600 or t < 300:
122+
# Track target
123+
mpc.switchToReach(target_new)
124+
print("Time " + str(t))
125+
for j in range(N_simu):
126+
u_interp = (N_simu - j) / N_simu * mpc.us[0] + j / N_simu * mpc.us[1]
127+
x_interp = (N_simu - j) / N_simu * mpc.xs[0] + j / N_simu * mpc.xs[1]
128+
129+
mpc.getDataHandler().updateInternalData(x_measured, True)
130+
131+
current_torque = u_interp - 1.0 * mpc.Ks[0] @ model_handler.difference(
132+
x_measured, x_interp
133+
)
134+
135+
simulator.reset()
136+
simulator.step(q, v, current_torque, dt_sim)
137+
138+
q, v = simulator.state.qnew.copy(), simulator.state.vnew.copy()
139+
x_measured = np.concatenate((q, v))
140+
141+
# Change current target in MPC
142+
target_new[0] = target_pos[0] + x_mov * np.sin(np.pi * t * freq_mov / 180)
143+
target_new[2] = target_pos[2] + z_mov * np.cos(np.pi * t * freq_mov / 180)
144+
145+
vizer.visual_model.geometryObjects[target_id1].placement.translation = target_new
146+
147+
mpc.iterate(x_measured)
148+
vizer.display(q)

0 commit comments

Comments
 (0)