Skip to content

Commit 070d515

Browse files
Merge pull request #24 from YannBerthelot/gymnax_unification
changed from t to time to unify with gymnax
2 parents a03a59e + 56b28e6 commit 070d515

27 files changed

+58
-55
lines changed

src/target_gym/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ class EnvParams:
99

1010
@struct.dataclass
1111
class EnvState:
12-
t: int
12+
time: int

src/target_gym/benchmark_speed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ def step_fn(carry, _):
4747

4848
# Warmup
4949
state = step_fn_jit(state)
50-
state.t.block_until_ready()
50+
state.time.block_until_ready()
5151

5252
# Timing
5353
t0 = time.time()
5454
state = step_fn_jit(state)
55-
state.t.block_until_ready()
55+
state.time.block_until_ready()
5656
t1 = time.time()
5757

5858
total_steps = steps * batch_size

src/target_gym/bicycle/env.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
jnp = None
1212

1313
# Import your integrator
14+
from target_gym.base import EnvParams, EnvState
1415
from target_gym.integration import (
1516
integrate_dynamics,
1617
)
@@ -19,7 +20,7 @@
1920

2021

2122
@struct.dataclass
22-
class BikeState:
23+
class BikeState(EnvState):
2324
omega: float # tilt angle [rad]
2425
omega_dot: float # tilt angular velocity
2526
theta: float # steering angle [rad]
@@ -30,7 +31,6 @@ class BikeState:
3031
x_b: float # back wheel x
3132
y_b: float # back wheel y
3233
last_d: float # last displacement action (normalized)
33-
t: int
3434
# for rendering
3535
torque: float = jnp.nan
3636
displacement: float = jnp.nan
@@ -41,7 +41,7 @@ def distance_from_start(self):
4141

4242

4343
@struct.dataclass
44-
class BikeParams:
44+
class BikeParams(EnvParams):
4545
c: float = 0.66
4646
dCM: float = 0.30
4747
h: float = 0.94
@@ -59,7 +59,6 @@ class BikeParams:
5959
delta_t: float = 0.05
6060

6161
max_tilt_deg: float = 12.0
62-
max_steps_in_episode: int = 1_000
6362

6463
use_goal: bool = False
6564
goal_x: float = 0.0
@@ -229,7 +228,7 @@ def vecs_to_state(
229228
x_b=xb_new,
230229
y_b=yb_new,
231230
last_d=metrics.get("d", state.last_d).squeeze(),
232-
t=state.t + 1,
231+
time=state.time + 1,
233232
)
234233

235234

@@ -279,7 +278,7 @@ def _accel_fn(velocities, positions):
279278
def check_is_terminal(state: BikeState, params: BikeParams):
280279
max_tilt_rad = jnp.deg2rad(params.max_tilt_deg)
281280
terminated = jnp.abs(state.omega) > max_tilt_rad
282-
truncated = state.t >= params.max_steps_in_episode
281+
truncated = state.time >= params.max_steps_in_episode
283282
return terminated, truncated
284283

285284

src/target_gym/bicycle/env_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def reset_env(
120120
x_b=zero - params.l,
121121
y_b=zero,
122122
last_d=zero,
123-
t=0,
123+
time=0,
124124
torque=zero,
125125
displacement=zero,
126126
)

src/target_gym/bicycle/rendering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def world_to_screen(x, y):
9191
left_texts = [
9292
f"Lean ω: {np.rad2deg(state.omega):.1f}°",
9393
f"Steer θ: {np.rad2deg(state.theta):.1f}°",
94-
f"Step: {state.t}",
94+
f"Step: {state.time}",
9595
]
9696
right_texts = [
9797
f"Heading ψ: {np.rad2deg(state.psi):.1f}°",

src/target_gym/car/env.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def check_is_terminal(state: EnvState, params: CarParams, xp=jnp):
125125
terminated = jnp.logical_or(
126126
state.velocity <= params.min_velocity, state.velocity >= params.max_velocity
127127
)
128-
truncated = state.t >= params.max_steps_in_episode
128+
truncated = state.time >= params.max_steps_in_episode
129129
return terminated, truncated
130130

131131

@@ -223,7 +223,9 @@ def compute_next_state(
223223
)
224224

225225
return (
226-
state.replace(x=position, velocity=velocity, throttle=throttle, t=state.t + 1),
226+
state.replace(
227+
x=position, velocity=velocity, throttle=throttle, time=state.time + 1
228+
),
227229
metrics,
228230
)
229231

src/target_gym/car/env_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def reset_env(
104104
)
105105

106106
state = CarState(
107-
t=0,
107+
time=0,
108108
x=initial_x,
109109
velocity=initial_velocity,
110110
target_velocity=target_velocity,

src/target_gym/pc_gym/cstr/env.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def compute_next_state(
8787
method=integration_method,
8888
)
8989
return (
90-
state.replace(C_a=C_a, T=T, T_c=T_c_raw, t=state.t + 1),
90+
state.replace(C_a=C_a, T=T, T_c=T_c_raw, time=state.time + 1),
9191
metrics,
9292
)
9393

@@ -106,7 +106,7 @@ def check_is_terminal(state: CSTRState, params: CSTRParams, xp=jnp):
106106
state.C_a <= params.C_a_min, state.C_a >= params.C_a_max
107107
)
108108
terminated = jnp.logical_or(terminated_1, terminated_2)
109-
truncated = state.t >= params.max_steps_in_episode
109+
truncated = state.time >= params.max_steps_in_episode
110110
return terminated, truncated
111111

112112

src/target_gym/pc_gym/cstr/env_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def reset_env(
104104
)
105105

106106
state = CSTRState(
107-
t=0,
107+
time=0,
108108
C_a=initial_C_a,
109109
T=params.initial_T,
110110
target_CA=initial_target_C_a,

src/target_gym/pc_gym/cstr/rendering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _render(cls, screen, state, params, frames, clock, stride: int = 10):
9494
if not hasattr(cls, "history"):
9595
cls.history = {"t": [], "C_a": [], "T": [], "T_c": [], "reward": []}
9696

97-
step = state.t
97+
step = state.time
9898
# Only render every `stride` steps
9999
if step % stride == 0 or step == 1:
100100
frame, cls.history = render_cstr(state, params, step, cls.history)

0 commit comments

Comments
 (0)