Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ As well as environments adapted from [Process-Control Gym](https://github.com/Ma
* 🧪 **CSTR** - control of a chemical reaction in a continuous stirred-tank reactor (CSTR).
* **More to come**

| Environment | Target Type | Action Dim | State Dim | Steps / Second (1 env - $10^8$ steps) |
|-------------|---------------------|------------|-----------|-----------------|
| 🛩 Plane | Stable-Target-MDP | 2 (power, stick) | 9 (z, ż, ẋ, θ, θ̇, …) | ~0.54M (CPU) |
| 🚗 Car | Stable-Target-MDP | 1 (throttle/brake) | 12 (velocity, lidar, …) | ~0.85M |
| 🚲 Bike | Unstable-Target-MDP | 2 (torque, displacement) | 5 (angle, angular vel, pos, …) | ~1.77M |
| 🧪 CSTR | Stable-Target-MDP | 1 (coolant temperature) | 3 (concentration, temperature, target temperature) | ~1.49M |

<table align="center">
<tr>
<td align="center">
Expand Down
105 changes: 105 additions & 0 deletions src/target_gym/benchmark_speed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import time

import jax
import jax.numpy as jnp


def benchmark_env(env, params, steps: int = 1_000_000, batch_size: int = 1024):
"""
Benchmark how many steps per second an environment can run in parallel.

Args:
env: JAX environment (must implement reset_env, step)
params: environment parameters
steps: number of timesteps to simulate
batch_size: number of parallel environments

Returns:
steps_per_second: float
"""
key = jax.random.PRNGKey(0)

# Reset batch of environments
keys = jax.random.split(key, batch_size)
obs, state = jax.vmap(env.reset_env, in_axes=(0, None))(keys, params)

# Dummy action (zeros)
action_dim = env.action_space(params).shape[0]
# actions = (
# jnp.zeros((batch_size, action_dim))
# if action_dim > 1
# else jnp.zeros((batch_size,))
# )
actions = jnp.zeros((batch_size, action_dim))

def step_fn(carry, _):
state = carry
key = jax.random.PRNGKey(0) # reuse same rng (no stochasticity assumed)
obs, new_state, reward, done, info = jax.vmap(
env.step, in_axes=(None, 0, 0, None)
)(key, state, actions, params)
return new_state, None

# JIT compile scan loop
step_fn_jit = jax.jit(
lambda s: jax.lax.fori_loop(0, steps, lambda _, st: step_fn(st, None)[0], s)
)

# Warmup
state = step_fn_jit(state)
state.t.block_until_ready()

# Timing
t0 = time.time()
state = step_fn_jit(state)
state.t.block_until_ready()
t1 = time.time()

total_steps = steps * batch_size
steps_per_second = total_steps / (t1 - t0)

return steps_per_second


if __name__ == "__main__":
from target_gym import (
CSTR,
Bike,
BikeParams,
Car,
CarParams,
CSTRParams,
Plane,
PlaneParams,
)

N_steps = int(1e8)
max_steps_in_episode = 10_000

plane_env = Plane()
plane_params = PlaneParams(max_steps_in_episode=max_steps_in_episode)
car_env = Car()
car_params = CarParams(max_steps_in_episode=max_steps_in_episode)

bike_env = Bike()
bike_params = BikeParams(max_steps_in_episode=max_steps_in_episode)

cstr_env = CSTR()
cstr_params = CSTRParams(max_steps_in_episode=max_steps_in_episode)

print(
"Plane M-steps/sec:",
benchmark_env(plane_env, plane_params, steps=N_steps, batch_size=1) / int(1e6),
)
print(
"Car M-steps/sec:",
benchmark_env(car_env, car_params, steps=N_steps, batch_size=1) / int(1e6),
)
print(
"Bike M-steps/sec:",
benchmark_env(bike_env, bike_params, steps=N_steps, batch_size=1) / int(1e6),
)
print(
"CSTR M-steps/sec:",
benchmark_env(cstr_env, cstr_params, steps=N_steps, batch_size=1) / int(1e6),
)
21 changes: 8 additions & 13 deletions src/target_gym/car/env.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
from typing import Callable, Optional, Tuple

import chex
import jax
import jax.numpy as jnp

# Optional: jax imports (only used in jax env)
import numpy as np
from flax import struct
from jax.tree_util import Partial as partial

try:
import chex
import jax
import jax.numpy as jnp
except ImportError:
jax = None
jnp = None
chex = None

from jax import grad
from jax.tree_util import Partial as partial

from target_gym.base import EnvParams, EnvState
from target_gym.integration import (
Expand Down Expand Up @@ -89,13 +83,14 @@ def electric_torque_from_rpm(rpm: float, throttle: float, params: CarParams):
return torque


def road_profile(x):
def road_profile(x: float | jax.Array):
"""
More realistic road profile with alternating climbs, plateaus, and descents.
Designed to challenge velocity maintenance.
Elevation changes are on the order of ±100 m.
"""
# Normalize input to kilometers for readability
x = x.squeeze() if isinstance(x, jax.Array) else x
km = x / 1000.0

# Long-term trend: alternating climbs/descents
Expand All @@ -112,7 +107,7 @@ def road_profile(x):

# Small irregularities (avoid perfectly flat sections)
roughness = jnp.sin(km * 3.5) * 2.0 + jnp.sin(km * 11.0) * 1.0
return (trend + plateaus + roughness) * 0
return trend + plateaus + roughness


@partial(jax.jit, static_argnames=["road_profile"])
Expand Down
2 changes: 1 addition & 1 deletion src/target_gym/car/env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def step_env(
"""
if params is None:
params = self.default_params
throttle = action
throttle = action.reshape(()) # to get scalar and not 1D array

new_state, metrics = compute_next_state(
throttle, state, params, integration_method=self.integration_method
Expand Down
6 changes: 6 additions & 0 deletions src/target_gym/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,14 @@ def step_fn(p, _):
def step_fn(carry, _):
v, p = carry
v_new, p_new, metrics = second_order_step(v, p, h)
v_new = v_new.reshape(v.shape)
p_new = p_new.reshape(p.shape)
return (v_new, p_new), metrics

# if jnp.ndim(velocities) == 0:
# velocities = velocities.reshape((1,))
# if jnp.ndim(positions) == 0:
# positions = positions.reshape((1,))
(new_velocities, new_positions), metrics = jax.lax.scan(
step_fn, (velocities, positions), xs=None, length=n_substeps
)
Expand Down
3 changes: 3 additions & 0 deletions src/target_gym/pc_gym/cstr/env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def step_env(
"""
if params is None:
params = self.default_params

T_c = action
if not isinstance(action, float):
T_c = action.reshape(())

new_state, metrics = compute_next_state(
T_c, state, params, integration_method=self.integration_method
Expand Down