Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified figures/bike/3d_altitude.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/bike/power_trajectories.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/car/throttle_trajectories.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/cstr/trajectories.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/plane/3d_altitude.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/plane/power_trajectories.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2,286 changes: 1,398 additions & 888 deletions poetry.lock

Large diffs are not rendered by default.

72 changes: 36 additions & 36 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,29 @@ packages = [
enable = true

[tool.poetry.dependencies]
python = ">=3.10,<3.13"
gymnax = "^0.0.9"
numpy = "^1.26.1"
six = "^1.16.0"
matplotlib = "^3.8.1"
gymnasium = "^1.2.0"
distrax = "^0.1.5"
jax = "<7.0"
moviepy = "^2.2.1"
pygame = "^2.6.1"
pre-commit = { version = "^3.5.0", optional = true }
pytest = { version = "^7.4.3", optional = true }
codecov = { version = "^2.1.13", optional = true }
coverage = { version = "^7.3.2", optional = true }
pytest-cov = { version = "^4.1.0", optional = true }
pytest-readme = { version = "^1.0.2", optional = true }
isort = { version = "^5.12.0", optional = true }
mypy = { version = "^1.6.1", optional = true }
flake8 = { version = "^7.3.0", optional = true }
ruff = { version = "^0.12.11", optional = true }
black = { version = "^25.1.0", optional = true }
stable-baselines3 = { version = "^2.7.0", optional = true }
tensorboard = { version = "^2.20.0", optional = true }
python = ">=3.10"
gymnax = ">=0.0.9"
numpy = ">=1.26.1"
six = ">=1.16.0"
matplotlib = ">=3.8.1"
gymnasium = ">=1.2.0"
distrax = ">=0.1.5"
jax = ">=0.6.2"
moviepy = ">=2.2.1"
pygame = ">=2.6.1"
pre-commit = { version = ">=3.5.0", optional = true }
pytest = { version = ">=7.4.3", optional = true }
codecov = { version = ">=2.1.13", optional = true }
coverage = { version = ">=7.3.2", optional = true }
pytest-cov = { version = ">=4.1.0", optional = true }
pytest-readme = { version = ">=1.0.2", optional = true }
isort = { version = ">=5.12.0", optional = true }
mypy = { version = ">=1.6.1", optional = true }
flake8 = { version = ">=7.3.0", optional = true }
ruff = { version = ">=0.12.11", optional = true }
black = { version = ">=25.1.0", optional = true }
stable-baselines3 = { version = ">=2.7.0", optional = true }
tensorboard = { version = ">=2.20.0", optional = true }

[tool.black]
line-length = 88
Expand Down Expand Up @@ -83,25 +83,25 @@ optional = true
optional = true

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.5.0"
pre-commit = ">=3.5.0"

[tool.poetry.group.test.dependencies]
pytest = "^7.4.3"
codecov = "^2.1.13"
coverage = "^7.3.2"
pytest-cov = "^4.1.0"
pytest-readme = "^1.0.2"
pytest = ">=7.4.3"
codecov = ">=2.1.13"
coverage = ">=7.3.2"
pytest-cov = ">=4.1.0"
pytest-readme = ">=1.0.2"

[tool.poetry.group.lint.dependencies]
isort = "^5.12.0"
mypy = "^1.6.1"
flake8 = "^7.3.0"
ruff = "^0.12.11"
black = "^25.1.0"
isort = ">=5.12.0"
mypy = ">=1.6.1"
flake8 = ">=7.3.0"
ruff = ">=0.12.11"
black = ">=25.1.0"

[tool.poetry.group.agent.dependencies]
stable-baselines3 = "^2.7.0"
tensorboard = "^2.20.0"
stable-baselines3 = ">=2.7.0"
tensorboard = ">=2.20.0"

[tool.poetry.extras]
dev = [
Expand Down
3 changes: 3 additions & 0 deletions src/target_gym/bicycle/env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def __init__(self, integration_method: str = "rk4_1"):
def default_params(self) -> BikeParams:
return BikeParams()

def compute_reward(self, state, params):
    """Return the reward for ``state`` under ``params``.

    Thin method wrapper: the bare name ``compute_reward`` used in the body
    resolves to the function of the same name in the enclosing module scope
    (not to this method), and its result is returned unchanged.
    """
    reward = compute_reward(state, params)
    return reward

def step_env(
self,
key: chex.PRNGKey,
Expand Down
4 changes: 3 additions & 1 deletion src/target_gym/bicycle/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def world_to_screen(x, y):
hy1 = yf + np.sin(handle_angle + np.pi / 2) * handle_length / 2
hx2 = xf + np.cos(handle_angle - np.pi / 2) * handle_length / 2
hy2 = yf + np.sin(handle_angle - np.pi / 2) * handle_length / 2
pygame.draw.line(surf, (200, 0, 0), (hx1, hy1), (hx2, hy2), 3)
pygame.draw.line(
surf, (200, 0, 0), (hx1.item(), hy1.item()), (hx2.item(), hy2.item()), 3
)

# --- HUD ---
font = pygame.font.SysFont("Arial", 16)
Expand Down
3 changes: 3 additions & 0 deletions src/target_gym/car/env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __init__(self, integration_method: str = "rk4_1"):
def default_params(self) -> CarParams:
return CarParams()

def compute_reward(self, state, params):
    """Compute the reward associated with ``state`` given ``params``.

    Delegates to the ``compute_reward`` function visible in the module scope
    and hands its return value back to the caller untouched.
    """
    result = compute_reward(state, params)
    return result

def step_env(
self,
key: chex.PRNGKey,
Expand Down
115 changes: 55 additions & 60 deletions src/target_gym/interpolator.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,80 @@
import os
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Union

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from scipy.interpolate import LinearNDInterpolator, interp1d

from target_gym import CSTR, Bike, Car, Plane
from target_gym.runners.utils import run_input_grid


def jax_interp1d(xq, xp, fp, left=jnp.nan, right=jnp.nan):
    """
    Minimal JAX-compatible version of interp1d (linear only).

    Assumes xp is sorted in ascending order (the lookup relies on
    ``jnp.searchsorted``).

    Args:
        xq: query point(s) at which to evaluate the interpolant.
        xp: 1-D array of sample locations.
        fp: 1-D array of sample values, same length as ``xp``.
        left: value returned where ``xq < xp[0]``.
        right: value returned where ``xq > xp[-1]``.

    Returns:
        Interpolated value(s) with the shape of ``xq``.
    """
    xq = jnp.asarray(xq)

    # Left knot of the interval holding each query; clipping keeps
    # ``lo + 1`` a valid index for queries at or beyond the ends of ``xp``.
    lo = jnp.clip(jnp.searchsorted(xp, xq) - 1, 0, xp.shape[0] - 2)
    x0, x1 = xp[lo], xp[lo + 1]
    y0, y1 = fp[lo], fp[lo + 1]

    # Repeated knots give a zero-width interval; dividing by it would turn an
    # in-range query into NaN/inf. Use a dummy denominator there and fall back
    # to the left sample value instead.
    dx = x1 - x0
    degenerate = dx == 0
    safe_dx = jnp.where(degenerate, 1.0, dx)
    slope = (y1 - y0) / safe_dx
    yq = jnp.where(degenerate, y0, y0 + slope * (xq - x0))

    # Handle out-of-bounds
    yq = jnp.where(xq < xp[0], left, yq)
    yq = jnp.where(xq > xp[-1], right, yq)
    return yq


def build_input_interpolator_from_df(
df: pd.DataFrame,
input_names: Union[str, List[str]] = "input",
output_name: str = "final_value",
):
"""
Build interpolator(s) for single- or two-input environments.

Args:
df: DataFrame with columns [input1, input2 (optional), final_value]
input_names: string or list of input column names
output_name: name of the column containing the output (state attribute)

Returns:
interp1d function or dict of interp1d functions
Build JAX-compatible interpolators for single- or two-input environments.
"""
if isinstance(input_names, str):
# Single-input environment
df_sorted = df.sort_values(output_name)
inputs = df_sorted[input_names].to_numpy()
outputs = df_sorted[output_name].to_numpy()
inputs = jnp.array(df_sorted[input_names].to_numpy())
outputs = jnp.array(df_sorted[output_name].to_numpy())

# Check monotonicity
if not (np.all(np.diff(outputs) >= 0) or np.all(np.diff(outputs) <= 0)):
# Check monotonicity (outside JAX, since it's a one-time check)
diffs = jnp.diff(outputs)
if not (jnp.all(diffs >= 0) | jnp.all(diffs <= 0)):
raise ValueError(
f"Output not monotonic for {input_names}, interpolation ambiguous."
)

interpolator = interp1d(
outputs, inputs, bounds_error=False, fill_value=np.nan, kind="linear"
)
def interpolator(query_outputs):
return jax_interp1d(query_outputs, outputs, inputs)

return interpolator

elif isinstance(input_names, list) and len(input_names) == 2:
# Two-input environment
input1, input2 = input_names
interpolators: Dict[float, interp1d] = {}
interpolators: Dict[float, callable] = {}
tol = 1e-6

# Build an interpolator for each fixed value of the second input
for val2 in np.unique(df[input2]):
df_fixed = df[np.abs(df[input2] - val2) < tol].sort_values(output_name)
for val2 in jnp.unique(jnp.array(df[input2].to_numpy())):
df_fixed = df[abs(df[input2] - float(val2)) < tol].sort_values(output_name)
if df_fixed.empty:
continue
outputs = df_fixed[output_name].to_numpy()
inputs = df_fixed[input1].to_numpy()
outputs = jnp.array(df_fixed[output_name].to_numpy())
inputs = jnp.array(df_fixed[input1].to_numpy())

if not (np.all(np.diff(outputs) >= 0) or np.all(np.diff(outputs) <= 0)):
diffs = jnp.diff(outputs)
if not (jnp.all(diffs >= 0) | jnp.all(diffs <= 0)):
raise ValueError(
f"Output not monotonic for {input1} at {input2}={val2}, ambiguous."
)

interpolators[val2] = interp1d(
outputs, inputs, bounds_error=False, fill_value=np.nan, kind="linear"
)
def make_interp(xp, fp):
return lambda xq: jax_interp1d(xq, xp, fp)

interpolators[float(val2)] = make_interp(outputs, inputs)

return interpolators

Expand All @@ -79,16 +89,7 @@ def get_interpolator_from_run(
output_name: str = "final_value",
):
"""
Run the grid function and build interpolators.

Args:
run_func: function returning (final_values, df)
run_kwargs: kwargs to pass to run_func
input_names: single input column or list of two input columns
output_name: column name of the output (state attribute)

Returns:
interp1d function (single-input) or dict of interp1d (two-input)
Run the grid function and build interpolators (JAX version).
"""
_, df = run_func(**run_kwargs)
return build_input_interpolator_from_df(
Expand Down Expand Up @@ -123,55 +124,49 @@ def build_env_interpolator(
second_input_name=input_names[1] if len(input_names) == 2 else None,
state_attr=state_attr,
)

if env_class is Plane:
df = df[df["final_value"] > 0]

if len(input_names) == 1:
# Single input: interp1d
return interp1d(
df["final_value"].to_numpy(),
df[input_names[0]].to_numpy(),
bounds_error=False,
fill_value=np.nan,
kind="linear",
return lambda q: jax_interp1d(
q,
jnp.array(df["final_value"].to_numpy()),
jnp.array(df[input_names[0]].to_numpy()),
)
else:
# Two inputs: only keep second_input=0 for round-trip
df0 = df[df[input_names[1]] == 0.0].sort_values("final_value")
return interp1d(
df0["final_value"].to_numpy(),
df0[input_names[0]].to_numpy(),
bounds_error=False,
fill_value=np.nan,
kind="nearest",
return lambda q: jax_interp1d(
q,
jnp.array(df0["final_value"].to_numpy()),
jnp.array(df0[input_names[0]].to_numpy()),
)


def get_interpolator(env_class, env_params, resolution: int = 100, steps: int = 10_000):
mapping = ENV_IO_MAPPING[env_class]
input_names = mapping["input_names"]

# Automatically set input grids
if len(input_names) == 2:

if env_class == Plane:
first_input = jnp.linspace(0, 1.0, resolution)
first_input = jnp.linspace(-1.0, 1.0, resolution)
second_input = jnp.zeros(1)
else:
first_input = jnp.linspace(-1.0, 1.0, resolution)
second_input = jnp.linspace(-1.0, 1.0, resolution)
else:
env_instance = env_class()
try:
min_val = float(env_instance.action_space(env_params).low[0])
max_val = float(env_instance.action_space(env_params).high[0])
bounds = env_instance.action_space(env_params)
min_val = float(bounds.low[0])
max_val = float(bounds.high[0])
if env_class == Car:
min_val = max(min_val, 0.0)
except Exception:
min_val, max_val = -1.0, 1.0
first_input = jnp.linspace(min_val, max_val, resolution)
second_input = None

# Build interpolator
interp = build_env_interpolator(
env_class,
env_params,
Expand Down
3 changes: 3 additions & 0 deletions src/target_gym/pc_gym/cstr/env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def __init__(self, integration_method: str = "rk4_1"):
def default_params(self) -> CSTRParams:
return CSTRParams()

def compute_reward(self, state, params):
    """Evaluate the reward of ``state`` for the given ``params``.

    Forwards both arguments to the ``compute_reward`` function found in the
    surrounding module scope and returns whatever that call produces.
    """
    value = compute_reward(state, params)
    return value

def step_env(
self,
key: chex.PRNGKey,
Expand Down
9 changes: 7 additions & 2 deletions src/target_gym/plane/env_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def __init__(self, integration_method: str = "rk4_1"):
def default_params(self) -> PlaneParams:
return PlaneParams()

def compute_reward(self, state, params):
    """Return the reward for ``state`` under ``params`` using the JAX backend.

    Delegates to the ``compute_reward`` function in the module scope.
    ``xp=jnp`` is forwarded explicitly: ``step_env`` previously called that
    function with ``xp=jnp`` and still passes ``xp=jnp`` to
    ``check_is_terminal``, so the wrapper keeps the array backend consistent
    instead of relying on the function's default.
    """
    return compute_reward(state, params, xp=jnp)

def step_env(
self,
key: chex.PRNGKey,
Expand All @@ -49,13 +52,15 @@ def step_env(
"""
if params is None:
params = self.default_params

power, stick = action
power = (power + 1) / 2 # map from [-1, 1] to [0, 1]
stick = jnp.deg2rad(stick * 15) # radians

new_state, metrics = compute_next_state(
power, stick, state, params, integration_method=self.integration_method
)
reward = compute_reward(new_state, params, xp=jnp)
reward = self.compute_reward(new_state, params)
terminated, truncated = check_is_terminal(new_state, params, xp=jnp)
done = terminated | truncated

Expand Down Expand Up @@ -167,7 +172,7 @@ def save_video(
def action_space(self, params: PlaneParams | None = None) -> spaces.Discrete:
"""Action space of the environment."""
return spaces.Box(
low=jnp.array([0.0, -1.0]),
low=jnp.array([-1.0, -1.0]),
high=jnp.array([1.0, 1.0]),
shape=(2,),
dtype=jnp.float32,
Expand Down
Loading