Skip to content

Commit da65bc5

Browse files
committed
more flexible config
1 parent a7d3749 commit da65bc5

File tree

1 file changed

+64
-40
lines changed

1 file changed

+64
-40
lines changed

proj/model/config.py

Lines changed: 64 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,59 @@
11
import numpy as np
22

3+
# ------------------------------- Model specific params ------------------------------ #
4+
_cart_params = dict(
5+
STATE_SIZE=5,
6+
INPUT_SIZE=2,
7+
ANGLE_IDX=2, # state vector index which is angle, used to fit diff in
8+
R=np.diag([0.05, 0.05]), # control cost
9+
Q=np.diag([1, 1, 1, 1, 0]), # state cost | x, y, theta, v, omega
10+
Sf=np.diag([0, 0, 0, 0, 0]), # final state cost
11+
)
12+
13+
_polar_params = dict(
14+
STATE_SIZE=4,
15+
INPUT_SIZE=2,
16+
ANGLE_IDX=2, # state vector index which is angle, used to fit diff in
17+
R=np.diag([0.05, 0.05]), # control cost
18+
Q=np.diag([2.5, 2.5, 0, 0]), # state cost | r, omega, v, omega
19+
Sf=np.diag([2.5, 2.5, 0, 0]), # final state cost
20+
)
21+
22+
# -------------------------------- Mouse types ------------------------------- #
23+
24+
_easy_mouse = dict(
25+
L=1.5, # half body width | cm
26+
R=1, # radius of wheels | cm
27+
d=0.1, # distance between axel and CoM | cm
28+
length=3, # cm
29+
m=round(20 / 9.81, 2), # mass | g
30+
m_w=round(2 / 9.81, 2), # mass of wheels/legs |g
31+
mouse_type="easy",
32+
)
33+
34+
_realistic_mouse = dict(
35+
L=2, # half body width | cm
36+
R=2, # radius of wheels | cm
37+
d=3, # distance between axel and CoM | cm
38+
length=8.6, # cm
39+
m=round(25 / 9.81, 2), # mass | g
40+
m_w=round(0.6 / 9.81, 2), # mass of wheels/legs |g
41+
mouse_type="realistic",
42+
)
43+
344

445
class Config:
46+
# ----------------------------- Simulation params ---------------------------- #
547
SIMULATION_NAME = ""
648

749
USE_FAST = True # if true use cumba's methods
850
SPAWN_TYPE = "trajectory"
951
LIVE_PLOT = True
1052

11-
# ----------------------------- Simulation params ---------------------------- #
12-
dt = 0.005
13-
14-
# -------------------------------- Cost params ------------------------------- #
15-
STATE_SIZE = 5
16-
INPUT_SIZE = 2
17-
ANGLE_IDX = 2 # state vector index which is angle, used to fit diff in
18-
19-
R = np.diag([0.05, 0.05]) # control cost
20-
Q = np.diag([1, 1, 1, 1, 0]) # state cost | x, y, theta, v, omega
21-
Sf = np.diag([0, 0, 0, 0, 0]) # final state cost
22-
23-
# STATE_SIZE = 4
24-
# INPUT_SIZE = 2
25-
26-
# R = np.diag([0.05, 0.05]) # control cost
27-
# Q = np.diag([2.5, 2.5, 0, 0]) # state cost | r, omega, v, omega
28-
# Sf = np.diag([2.5, 2.5, 0, 0]) # final state cost
29-
30-
# ------------------------------- Mouse params ------------------------------- #
31-
# ? works
32-
mouse = dict(
33-
mouse_type="working",
34-
L=1.5, # half body width | cm
35-
R=1, # radius of wheels | cm
36-
d=0.1, # distance between axel and CoM | cm
37-
length=3, # cm
38-
m=round(20 / 9.81, 2), # mass | g
39-
m_w=round(2 / 9.81, 2), # mass of wheels/legs |g
40-
)
53+
mouse_type = "easy"
54+
model_type = "cart"
4155

42-
# ? more realistic
43-
# mouse = dict(
44-
# mouse_type = 'realistic',
45-
# L=2, # half body width | cm
46-
# R=2, # radius of wheels | cm
47-
# d=3, # distance between axel and CoM | cm
48-
# length=8.6, # cm
49-
# m=round(25 / 9.81, 2), # mass | g
50-
# m_w=round(0.6 / 9.81, 2), # mass of wheels/legs |g
51-
# )
56+
dt = 0.005
5257

5358
# ------------------------------ Goal trajectory ----------------------------- #
5459

@@ -84,6 +89,25 @@ class Config:
8489
threshold=1e-6,
8590
)
8691

92+
def __init__(self,):
93+
# get mouse params
94+
self.mouse = (
95+
_easy_mouse if self.mouse_type == "easy" else _realistic_mouse
96+
)
97+
98+
# set model params
99+
if self.model_type == "cart":
100+
params = _cart_params
101+
else:
102+
params = _polar_params
103+
104+
self.STATE_SIZE = params["STATE_SIZE"]
105+
self.INPUT_SIZE = params["INPUT_SIZE"]
106+
self.ANGLE_IDX = params["ANGLE_IDX"]
107+
self.R = params["R"]
108+
self.Q = params["Q"]
109+
self.Sf = params["Sf"]
110+
87111
def config_dict(self):
88112
return dict(
89113
dt=self.dt,

0 commit comments

Comments
 (0)