Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions src/target_gym/interpolator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import os
from typing import Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import numpy as np
import pandas as pd
from scipy.interpolate import LinearNDInterpolator, interp1d

from target_gym import CSTR, Bike, Car, Plane
from target_gym.runners.utils import run_input_grid


def build_input_interpolator_from_df(
df: pd.DataFrame,
input_names: Union[str, List[str]] = "input",
output_name: str = "final_value",
):
"""
Build interpolator(s) for single- or two-input environments.

Args:
df: DataFrame with columns [input1, input2 (optional), final_value]
input_names: string or list of input column names
output_name: name of the column containing the output (state attribute)

Returns:
interp1d function or dict of interp1d functions
"""
if isinstance(input_names, str):
# Single-input environment
df_sorted = df.sort_values(output_name)
inputs = df_sorted[input_names].to_numpy()
outputs = df_sorted[output_name].to_numpy()

# Check monotonicity
if not (np.all(np.diff(outputs) >= 0) or np.all(np.diff(outputs) <= 0)):
raise ValueError(
f"Output not monotonic for {input_names}, interpolation ambiguous."
)

interpolator = interp1d(
outputs, inputs, bounds_error=False, fill_value=np.nan, kind="linear"
)
return interpolator

elif isinstance(input_names, list) and len(input_names) == 2:
# Two-input environment
input1, input2 = input_names
interpolators: Dict[float, interp1d] = {}
tol = 1e-6

# Build an interpolator for each fixed value of the second input
for val2 in np.unique(df[input2]):
df_fixed = df[np.abs(df[input2] - val2) < tol].sort_values(output_name)
if df_fixed.empty:
continue
outputs = df_fixed[output_name].to_numpy()
inputs = df_fixed[input1].to_numpy()

if not (np.all(np.diff(outputs) >= 0) or np.all(np.diff(outputs) <= 0)):
raise ValueError(
f"Output not monotonic for {input1} at {input2}={val2}, ambiguous."
)

interpolators[val2] = interp1d(
outputs, inputs, bounds_error=False, fill_value=np.nan, kind="linear"
)

return interpolators

else:
raise ValueError("input_names must be a string or a list of 2 strings.")


def get_interpolator_from_run(
run_func: callable,
run_kwargs: dict,
input_names: Union[str, List[str]] = "input",
output_name: str = "final_value",
):
"""
Run the grid function and build interpolators.

Args:
run_func: function returning (final_values, df)
run_kwargs: kwargs to pass to run_func
input_names: single input column or list of two input columns
output_name: column name of the output (state attribute)

Returns:
interp1d function (single-input) or dict of interp1d (two-input)
"""
_, df = run_func(**run_kwargs)
return build_input_interpolator_from_df(
df, input_names=input_names, output_name=output_name
)


# Map each env to its input(s) and output state attribute
ENV_IO_MAPPING = {
Plane: {"input_names": ["power", "stick"], "state_attr": "z"},
Bike: {"input_names": ["power", "stick"], "state_attr": "z"},
Car: {"input_names": ["throttle"], "state_attr": "velocity"},
CSTR: {"input_names": ["T_c"], "state_attr": "T"},
}


def build_env_interpolator(
env_class, env_params, input_levels=None, second_input_levels=None, steps=10_000
):
mapping = ENV_IO_MAPPING[env_class]
input_names = mapping["input_names"]
state_attr = mapping["state_attr"]

env_instance = env_class()
final_values, df = run_input_grid(
input_levels,
env_instance,
env_params,
steps=steps,
input_name=input_names[0],
second_input_levels=second_input_levels,
second_input_name=input_names[1] if len(input_names) == 2 else None,
state_attr=state_attr,
)
if env_class is Plane:
df = df[df["final_value"] > 0]
if len(input_names) == 1:
# Single input: interp1d
return interp1d(
df["final_value"].to_numpy(),
df[input_names[0]].to_numpy(),
bounds_error=False,
fill_value=np.nan,
kind="linear",
)
else:
# Two inputs: only keep second_input=0 for round-trip
df0 = df[df[input_names[1]] == 0.0].sort_values("final_value")
return interp1d(
df0["final_value"].to_numpy(),
df0[input_names[0]].to_numpy(),
bounds_error=False,
fill_value=np.nan,
kind="nearest",
)


def get_interpolator(env_class, env_params, resolution: int = 100, steps: int = 10_000):
mapping = ENV_IO_MAPPING[env_class]
input_names = mapping["input_names"]

# Automatically set input grids
if len(input_names) == 2:

if env_class == Plane:
first_input = jnp.linspace(0, 1.0, resolution)
second_input = jnp.zeros(1)
else:
first_input = jnp.linspace(-1.0, 1.0, resolution)
second_input = jnp.linspace(-1.0, 1.0, resolution)
else:
env_instance = env_class()
try:
min_val = float(env_instance.action_space(env_params).low[0])
max_val = float(env_instance.action_space(env_params).high[0])
if env_class == Car:
min_val = max(min_val, 0.0)
except Exception:
min_val, max_val = -1.0, 1.0
first_input = jnp.linspace(min_val, max_val, resolution)
second_input = None

# Build interpolator
interp = build_env_interpolator(
env_class,
env_params,
input_levels=first_input,
second_input_levels=second_input,
steps=steps,
)
return interp
1 change: 0 additions & 1 deletion src/target_gym/runners/car_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def step_fn(carry, _):


def build_power_interpolator_from_df(df, stick=0.0):
raise NotImplementedError
tol = 1e-6
df_stick = df[np.abs(df["stick"] - stick) < tol]

Expand Down
121 changes: 121 additions & 0 deletions src/target_gym/runners/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import os
from typing import Optional, Tuple, Union

import jax
import jax.numpy as jnp
import pandas as pd


def run_constant_policy_final_value(
env,
params,
action: Union[float, Tuple[float, float]],
state_attr: str,
steps: int = 10_000,
key_seed: int = 0,
):
"""
Run a constant policy in a JAX environment and return the final value of a specified state attribute.
Works safely with JAX traced arrays.
"""
key = jax.random.PRNGKey(key_seed)
init_obs, init_state = env.reset_env(key, params)

def step_fn(carry, _):
key, state, done = carry
obs, new_state, reward, new_done, info = env.step_env(
key, state, action, params
)
truncated = new_state.t >= params.max_steps_in_episode
state = jax.lax.cond(done, lambda _: state, lambda _: new_state, operand=None)
done = jnp.logical_or(done, truncated)
value = getattr(new_state, state_attr)
last_value = (
getattr(info["last_state"], state_attr) if "last_state" in info else value
)
return (key, state, done), (value, last_value, done)

(_, final_state, done), (value_hist, last_value_hist, done_hist) = jax.lax.scan(
step_fn, (key, init_state, False), None, length=steps
)

# Get index of first True in done_hist
done_idx = jnp.argmax(done_hist)

# Safely handle first step case using lax.cond
final_value = jax.lax.cond(
done_idx > 0,
lambda idx: last_value_hist[idx - 1],
lambda _: last_value_hist[-1],
operand=done_idx,
)

return final_value


def run_input_grid(
input_levels: jnp.ndarray,
env,
params,
steps: int = 10_000,
input_name: str = "input",
second_input_levels: Optional[jnp.ndarray] = None,
second_input_name: Optional[str] = None,
state_attr: str = "velocity",
) -> Tuple[jnp.ndarray, pd.DataFrame]:
"""
Runs a grid of constant inputs on an environment.

Supports:
- Single-input environments: Car, CSTR
- Two-input environments: Plane, Bike

Args:
input_levels: 1D array of first input (throttle, T_c, power)
env: JAX environment
params: EnvParams
steps: timesteps to run
input_name: name for CSV column of first input
second_input_levels: 1D array of second input (stick), optional
second_input_name: CSV column name for second input, optional
state_attr: which state attribute to track ("velocity", "T", "z", etc.)

Returns:
final_values: jnp array of final state_attr values
df: pandas DataFrame with inputs and final values
"""
if second_input_levels is None:
# Single-input env
def run_one_input(u):
return run_constant_policy_final_value(
env, params, action=u, state_attr=state_attr, steps=steps, key_seed=0
)

final_values = jax.vmap(run_one_input)(input_levels)
df = pd.DataFrame({input_name: input_levels, "final_value": final_values})

else:
# Two-input env
def run_one_first_input(u):
return jax.vmap(
lambda v: run_constant_policy_final_value(
env,
params,
action=(u, v),
state_attr=state_attr,
steps=steps,
key_seed=0,
)
)(second_input_levels)

final_values = jax.vmap(run_one_first_input)(input_levels)

# Flatten arrays for DataFrame
df = pd.DataFrame(
{
input_name: jnp.repeat(input_levels, len(second_input_levels)),
second_input_name: jnp.tile(second_input_levels, len(input_levels)),
"final_value": final_values.flatten(),
}
)
return final_values, df
69 changes: 69 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import jax.numpy as jnp
import numpy as np
import pytest

from target_gym import CSTR, Bike, Car, CarParams, CSTRParams, Plane, PlaneParams
from target_gym.interpolator import (
ENV_IO_MAPPING,
get_interpolator,
)
from target_gym.runners.utils import run_input_grid

STEPS = 10_000
RESOLUTION = 100 # small for test speed


@pytest.mark.parametrize(
"env_class, env_params",
[
(Car, CarParams(max_steps_in_episode=STEPS)),
(CSTR, CSTRParams(max_steps_in_episode=STEPS)),
(Plane, PlaneParams(max_steps_in_episode=STEPS)),
# (Bike, BikeParams(max_steps_in_episode=STEPS)),
],
)
def test_interpolator_round_trip(env_class, env_params):

interpolator = get_interpolator(env_class, env_params, resolution=RESOLUTION)
mapping = ENV_IO_MAPPING[env_class]
input_names = mapping["input_names"]
state_attr = mapping["state_attr"]

# Automatically set input grids
if len(input_names) == 2:

if env_class == Plane:
first_input = jnp.linspace(0, 1.0, RESOLUTION)
second_input = jnp.zeros(1)
else:
first_input = jnp.linspace(-1.0, 1.0, RESOLUTION)
second_input = jnp.linspace(-1.0, 1.0, RESOLUTION)
else:
env_instance = env_class()
try:
min_val = float(env_instance.action_space(env_params).low[0])
max_val = float(env_instance.action_space(env_params).high[0])
if env_class == Car:
min_val = max(min_val, 0.0)
except Exception:
min_val, max_val = -1.0, 1.0
first_input = jnp.linspace(min_val, max_val, RESOLUTION)
second_input = None
# Run the grid to get actual outputs
final_values, df = run_input_grid(
first_input,
env_class(),
env_params,
steps=STEPS,
input_name=input_names[0],
second_input_levels=second_input,
second_input_name=input_names[1] if second_input is not None else None,
state_attr=state_attr,
)

predicted_inputs = interpolator(df["final_value"].to_numpy())
# Mask NaNs
mask = ~np.isnan(predicted_inputs)
np.testing.assert_allclose(
predicted_inputs[mask], df[input_names[0]].to_numpy()[mask], rtol=1e-2
)