Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/target_gym/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ class EnvParams:

@struct.dataclass
class EnvState:
t: int
time: int
4 changes: 2 additions & 2 deletions src/target_gym/benchmark_speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ def step_fn(carry, _):

# Warmup
state = step_fn_jit(state)
state.t.block_until_ready()
state.time.block_until_ready()

# Timing
t0 = time.time()
state = step_fn_jit(state)
state.t.block_until_ready()
state.time.block_until_ready()
t1 = time.time()

total_steps = steps * batch_size
Expand Down
11 changes: 5 additions & 6 deletions src/target_gym/bicycle/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
jnp = None

# Import your integrator
from target_gym.base import EnvParams, EnvState
from target_gym.integration import (
integrate_dynamics,
)
Expand All @@ -19,7 +20,7 @@


@struct.dataclass
class BikeState:
class BikeState(EnvState):
omega: float # tilt angle [rad]
omega_dot: float # tilt angular velocity
theta: float # steering angle [rad]
Expand All @@ -30,7 +31,6 @@ class BikeState:
x_b: float # back wheel x
y_b: float # back wheel y
last_d: float # last displacement action (normalized)
t: int
# for rendering
torque: float = jnp.nan
displacement: float = jnp.nan
Expand All @@ -41,7 +41,7 @@ def distance_from_start(self):


@struct.dataclass
class BikeParams:
class BikeParams(EnvParams):
c: float = 0.66
dCM: float = 0.30
h: float = 0.94
Expand All @@ -59,7 +59,6 @@ class BikeParams:
delta_t: float = 0.05

max_tilt_deg: float = 12.0
max_steps_in_episode: int = 1_000

use_goal: bool = False
goal_x: float = 0.0
Expand Down Expand Up @@ -229,7 +228,7 @@ def vecs_to_state(
x_b=xb_new,
y_b=yb_new,
last_d=metrics.get("d", state.last_d).squeeze(),
t=state.t + 1,
time=state.time + 1,
)


Expand Down Expand Up @@ -279,7 +278,7 @@ def _accel_fn(velocities, positions):
def check_is_terminal(state: BikeState, params: BikeParams):
max_tilt_rad = jnp.deg2rad(params.max_tilt_deg)
terminated = jnp.abs(state.omega) > max_tilt_rad
truncated = state.t >= params.max_steps_in_episode
truncated = state.time >= params.max_steps_in_episode
return terminated, truncated


Expand Down
2 changes: 1 addition & 1 deletion src/target_gym/bicycle/env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def reset_env(
x_b=zero - params.l,
y_b=zero,
last_d=zero,
t=0,
time=0,
torque=zero,
displacement=zero,
)
Expand Down
2 changes: 1 addition & 1 deletion src/target_gym/bicycle/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def world_to_screen(x, y):
left_texts = [
f"Lean ω: {np.rad2deg(state.omega):.1f}°",
f"Steer θ: {np.rad2deg(state.theta):.1f}°",
f"Step: {state.t}",
f"Step: {state.time}",
]
right_texts = [
f"Heading ψ: {np.rad2deg(state.psi):.1f}°",
Expand Down
6 changes: 4 additions & 2 deletions src/target_gym/car/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def check_is_terminal(state: EnvState, params: CarParams, xp=jnp):
terminated = jnp.logical_or(
state.velocity <= params.min_velocity, state.velocity >= params.max_velocity
)
truncated = state.t >= params.max_steps_in_episode
truncated = state.time >= params.max_steps_in_episode
return terminated, truncated


Expand Down Expand Up @@ -223,7 +223,9 @@ def compute_next_state(
)

return (
state.replace(x=position, velocity=velocity, throttle=throttle, t=state.t + 1),
state.replace(
x=position, velocity=velocity, throttle=throttle, time=state.time + 1
),
metrics,
)

Expand Down
2 changes: 1 addition & 1 deletion src/target_gym/car/env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def reset_env(
)

state = CarState(
t=0,
time=0,
x=initial_x,
velocity=initial_velocity,
target_velocity=target_velocity,
Expand Down
4 changes: 2 additions & 2 deletions src/target_gym/pc_gym/cstr/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def compute_next_state(
method=integration_method,
)
return (
state.replace(C_a=C_a, T=T, T_c=T_c_raw, t=state.t + 1),
state.replace(C_a=C_a, T=T, T_c=T_c_raw, time=state.time + 1),
metrics,
)

Expand All @@ -106,7 +106,7 @@ def check_is_terminal(state: CSTRState, params: CSTRParams, xp=jnp):
state.C_a <= params.C_a_min, state.C_a >= params.C_a_max
)
terminated = jnp.logical_or(terminated_1, terminated_2)
truncated = state.t >= params.max_steps_in_episode
truncated = state.time >= params.max_steps_in_episode
return terminated, truncated


Expand Down
2 changes: 1 addition & 1 deletion src/target_gym/pc_gym/cstr/env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def reset_env(
)

state = CSTRState(
t=0,
time=0,
C_a=initial_C_a,
T=params.initial_T,
target_CA=initial_target_C_a,
Expand Down
2 changes: 1 addition & 1 deletion src/target_gym/pc_gym/cstr/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _render(cls, screen, state, params, frames, clock, stride: int = 10):
if not hasattr(cls, "history"):
cls.history = {"t": [], "C_a": [], "T": [], "T_c": [], "reward": []}

step = state.t
step = state.time
# Only render every `stride` steps
if step % stride == 0 or step == 1:
frame, cls.history = render_cstr(state, params, step, cls.history)
Expand Down
4 changes: 2 additions & 2 deletions src/target_gym/plane/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def check_mass_does_not_increase(old_mass, new_mass, xp=jnp):
def check_is_terminal(state: PlaneState, params: PlaneParams, xp=jnp):
"""Return True if the episode should terminate."""
terminated = xp.logical_or(state.z <= params.min_alt, state.z >= params.max_alt)
truncated = state.t >= params.max_steps_in_episode
truncated = state.time >= params.max_steps_in_episode

# done = xp.logical_or(done_alt, done_steps)
return terminated, truncated
Expand Down Expand Up @@ -241,7 +241,7 @@ def compute_next_state(
power=power,
stick=stick,
fuel=state.fuel,
t=state.t + 1,
time=state.time + 1,
target_altitude=state.target_altitude,
)
return new_state, metrics
2 changes: 1 addition & 1 deletion src/target_gym/plane/env_gymnasium.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
# power=initial_power,
# stick=initial_stick,
# fuel=initial_fuel,
# t=0,
# time=0,
# target_altitude=target_altitude,
# )
# return self.get_obs(self.state), self.state
Expand Down
2 changes: 1 addition & 1 deletion src/target_gym/plane/env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def reset_env(self, key: chex.PRNGKey, params: PlaneParams = None):
power=initial_power,
stick=initial_stick,
fuel=initial_fuel,
t=0,
time=0,
target_altitude=target_altitude,
)

Expand Down
4 changes: 2 additions & 2 deletions src/target_gym/plane/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,11 @@ def rotate_and_translate(points):
text_stick = font.render(
f"Stick: {np.rad2deg(state.stick):.0f}°", True, (0, 0, 255)
)
time_elapsed = time.strftime("%H:%M:%S", time.gmtime(state.t))
time_elapsed = time.strftime("%H:%M:%S", time.gmtime(state.time))
text_time = font.render(f"Time: {time_elapsed}", True, (0, 0, 255))
# reward = 1 if abs(state.z - state.target_altitude) < 1_000 else 0
max_alt_diff = params.max_alt - params.min_alt
done1 = state.t >= params.max_steps_in_episode
done1 = state.time >= params.max_steps_in_episode
if done1:
reward = -1.0 * params.max_steps_in_episode
else:
Expand Down
2 changes: 1 addition & 1 deletion src/target_gym/runners/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def step_fn(carry, _):
obs, new_state, reward, new_done, info = env.step_env(
key, state, action, params
)
truncated = new_state.t >= params.max_steps_in_episode
truncated = new_state.time >= params.max_steps_in_episode
state = jax.lax.cond(done, lambda _: state, lambda _: new_state, operand=None)
done = jnp.logical_or(done, truncated)
value = getattr(new_state, state_attr)
Expand Down
2 changes: 1 addition & 1 deletion src/target_gym/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def save_video(
rewards += reward
if params is None and hasattr(env, "default_params"):
params = env.default_params
truncated = state.t >= params.max_steps_in_episode
truncated = state.time >= params.max_steps_in_episode
done = terminated | truncated

if hasattr(env, "render"):
Expand Down
2 changes: 1 addition & 1 deletion tests/bicycle/test_bicycle_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def default_state(default_params):
x_b=-default_params.l,
y_b=0.0,
last_d=0.0,
t=0,
time=0,
)


Expand Down
2 changes: 1 addition & 1 deletion tests/bicycle/test_bicycle_env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_reset_returns_obs_and_state(env):
# check state has expected attributes
assert hasattr(state, "omega")
assert hasattr(state, "x_f")
assert state.t == 0
assert state.time == 0


def test_step_returns_correct_tuple(env):
Expand Down
8 changes: 5 additions & 3 deletions tests/car/test_car_env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,16 @@ def test_get_obs_shape_and_type(env, key):
def test_is_terminal_logic(env):
params = env.default_params
# velocity below min -> terminated
state_low = CarState(x=0.0, velocity=-1.0, t=0, target_velocity=10.0, throttle=0.0)
state_low = CarState(
x=0.0, velocity=-1.0, time=0, target_velocity=10.0, throttle=0.0
)
terminated, truncated = env.is_terminal(state_low, params)
assert terminated
# velocity above max -> terminated
state_high = CarState(
x=0.0,
velocity=params.max_velocity + 1.0,
t=0,
time=0,
target_velocity=10.0,
throttle=0.5,
)
Expand All @@ -101,7 +103,7 @@ def test_is_terminal_logic(env):
state_trunc = CarState(
x=0.0,
velocity=10.0,
t=params.max_steps_in_episode,
time=params.max_steps_in_episode,
target_velocity=10.0,
throttle=0.5,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/car/test_car_physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def state(params):
return CarState(
x=0.0,
velocity=20.0, # ~72 km/h
t=0,
time=0,
throttle=0.0,
target_velocity=25.0,
)
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_compute_acceleration_with_throttle_and_brake(params):

def test_compute_next_state_progress(params, state):
s_next, _ = compute_next_state(1.0, state, params)
assert s_next.t == state.t + 1
assert s_next.time == state.time + 1
assert s_next.x > state.x
assert jnp.isfinite(s_next.velocity)

Expand Down
4 changes: 2 additions & 2 deletions tests/car/test_physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def state(params):
x=0.0,
velocity=20.0, # ~72 km/h
throttle=0.5,
t=0,
time=0,
target_velocity=25.0,
)

Expand Down Expand Up @@ -106,7 +106,7 @@ def test_compute_acceleration_against_drag(params):
def test_compute_next_state_progress(params, state):
s_next, _ = compute_next_state(1.0, state, params)
# Time should advance
assert s_next.t == state.t + 1
assert s_next.time == state.time + 1
# X should advance forward
assert s_next.x > state.x
# Velocity should remain finite
Expand Down
8 changes: 4 additions & 4 deletions tests/pc_gym/cstr/test_cstr_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_compute_velocity_shape_and_trends():
def test_compute_next_state_progression(method):
"""State should evolve and time should increment."""
params = CSTRParams(delta_t=0.1)
state = CSTRState(C_a=1.0, T=350.0, T_c=298.0, t=0, target_CA=0.85)
state = CSTRState(C_a=1.0, T=350.0, T_c=298.0, time=0, target_CA=0.85)

new_state, metrics = compute_next_state(
T_c_raw=0.5,
Expand All @@ -53,7 +53,7 @@ def test_compute_next_state_progression(method):

# State updates
assert isinstance(new_state, CSTRState)
assert new_state.t == state.t + 1
assert new_state.time == state.time + 1
assert not jnp.allclose(new_state.C_a, state.C_a)
assert not jnp.allclose(new_state.T, state.T)

Expand All @@ -71,7 +71,7 @@ def test_compute_next_state_progression(method):
def test_integration_methods_agree_for_small_dt():
"""RK4 and Euler should give similar results if dt is small."""
params = CSTRParams(delta_t=1e-4)
state = CSTRState(C_a=1.0, T=350.0, T_c=298.0, t=0, target_CA=0.85)
state = CSTRState(C_a=1.0, T=350.0, T_c=298.0, time=0, target_CA=0.85)

action = 0.3
new_state_euler, _ = compute_next_state(
Expand All @@ -88,7 +88,7 @@ def test_integration_methods_agree_for_small_dt():
def test_action_clipping():
"""Raw action should be scaled and then clamped between [T_c_min, T_c_max]."""
params = CSTRParams(delta_t=0.1)
state = CSTRState(C_a=1.0, T=350.0, T_c=298.0, t=0, target_CA=0.85)
state = CSTRState(C_a=1.0, T=350.0, T_c=298.0, time=0, target_CA=0.85)

# Very large raw input -> clipped to max
action = 100.0
Expand Down
8 changes: 4 additions & 4 deletions tests/pc_gym/cstr/test_cstr_env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_reset_env_returns_obs_and_state():
assert isinstance(state, CSTRState)

# Reset initializes time to 0
assert state.t == 0
assert state.time == 0

# Initial C_a is within allowed range
assert (
Expand Down Expand Up @@ -52,7 +52,7 @@ def test_step_env_updates_state_and_obs(method):

# State updated and time advanced
assert isinstance(state2, CSTRState)
assert state2.t == state.t + 1
assert state2.time == state.time + 1

# Reward is finite
assert jnp.isfinite(reward)
Expand Down Expand Up @@ -82,14 +82,14 @@ def test_action_and_observation_space():

def test_is_terminal_propagates_logic():
env = CSTR()
state = CSTRState(t=0, C_a=0.5, T=350.0, target_CA=0.6, T_c=298.0)
state = CSTRState(time=0, C_a=0.5, T=350.0, target_CA=0.6, T_c=298.0)
term, trunc = env.is_terminal(state, env.default_params)
# assert isinstance(result, jnp.ndarray)


def test_get_obs_matches_manual_call():
env = CSTR()
state = CSTRState(t=0, C_a=0.5, T=350.0, target_CA=0.6, T_c=298.0)
state = CSTRState(time=0, C_a=0.5, T=350.0, target_CA=0.6, T_c=298.0)

obs1 = env.get_obs(state)
obs2 = env.get_obs(state, env.default_params)
Expand Down
Loading