From 00c75c465e11dc76062719d126d98215d0b82aa2 Mon Sep 17 00:00:00 2001
From: YannBerthelot
Date: Sun, 21 Sep 2025 22:25:48 +0200
Subject: [PATCH] added interpolation utils

---
 src/target_gym/interpolator.py       | 258 +++++++++++++++++++++++++++
 src/target_gym/runners/car_runner.py |   1 -
 src/target_gym/runners/utils.py      | 152 ++++++++++++++++
 tests/test_interpolation.py          |  68 +++++++
 4 files changed, 478 insertions(+), 1 deletion(-)
 create mode 100644 src/target_gym/interpolator.py
 create mode 100644 src/target_gym/runners/utils.py
 create mode 100644 tests/test_interpolation.py

diff --git a/src/target_gym/interpolator.py b/src/target_gym/interpolator.py
new file mode 100644
index 0000000..3677d80
--- /dev/null
+++ b/src/target_gym/interpolator.py
@@ -0,0 +1,258 @@
+import os
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+from scipy.interpolate import LinearNDInterpolator, interp1d
+
+from target_gym import CSTR, Bike, Car, Plane
+from target_gym.runners.utils import run_input_grid
+
+# Absolute tolerance used when matching floating-point input levels
+_LEVEL_TOL = 1e-6
+
+
+def _inverse_interpolator(
+    df: pd.DataFrame, input_name: str, output_name: str, error_message: str
+) -> interp1d:
+    """
+    Build the inverse map (output -> input) from one slice of a grid run.
+
+    Args:
+        df: DataFrame holding at least the input and output columns
+        input_name: name of the input column
+        output_name: name of the output column
+        error_message: message of the error raised on a non-monotonic output
+
+    Returns:
+        interp1d function mapping an output value to an input value (NaN
+        outside the range of observed outputs)
+
+    Raises:
+        ValueError: if the output is not monotonic in the input
+    """
+    # Sort by the input: after sorting by the output the monotonicity check
+    # below would always pass and ambiguous mappings would go unnoticed.
+    df_sorted = df.sort_values(input_name)
+    inputs = df_sorted[input_name].to_numpy()
+    outputs = df_sorted[output_name].to_numpy()
+
+    # Check monotonicity
+    deltas = np.diff(outputs)
+    if not (np.all(deltas >= 0) or np.all(deltas <= 0)):
+        raise ValueError(error_message)
+
+    # interp1d sorts its x values itself (assume_sorted=False by default)
+    return interp1d(
+        outputs, inputs, bounds_error=False, fill_value=np.nan, kind="linear"
+    )
+
+
+def build_input_interpolator_from_df(
+    df: pd.DataFrame,
+    input_names: Union[str, List[str]] = "input",
+    output_name: str = "final_value",
+):
+    """
+    Build interpolator(s) for single- or two-input environments.
+
+    Args:
+        df: DataFrame with columns [input1, input2 (optional), final_value]
+        input_names: string or list of one or two input column names
+        output_name: name of the column containing the output (state attribute)
+
+    Returns:
+        interp1d function (single input) or dict of interp1d functions keyed
+        by the value of the second input (two inputs)
+
+    Raises:
+        ValueError: if the output is not monotonic in the first input, or if
+            input_names is not a string or a list of 1 or 2 strings
+    """
+    # ENV_IO_MAPPING stores single inputs as one-element lists: accept them
+    if isinstance(input_names, (list, tuple)) and len(input_names) == 1:
+        input_names = input_names[0]
+
+    if isinstance(input_names, str):
+        # Single-input environment
+        return _inverse_interpolator(
+            df,
+            input_names,
+            output_name,
+            f"Output not monotonic for {input_names}, interpolation ambiguous.",
+        )
+
+    if isinstance(input_names, (list, tuple)) and len(input_names) == 2:
+        # Two-input environment
+        input1, input2 = input_names
+        interpolators: Dict[float, interp1d] = {}
+
+        # Build an interpolator for each fixed value of the second input
+        for val2 in np.unique(df[input2]):
+            df_fixed = df[np.abs(df[input2] - val2) < _LEVEL_TOL]
+            if df_fixed.empty:
+                continue
+            interpolators[val2] = _inverse_interpolator(
+                df_fixed,
+                input1,
+                output_name,
+                f"Output not monotonic for {input1} at {input2}={val2}, ambiguous.",
+            )
+
+        return interpolators
+
+    raise ValueError("input_names must be a string or a list of 1 or 2 strings.")
+
+
+def get_interpolator_from_run(
+    run_func: Callable[..., Tuple[jnp.ndarray, pd.DataFrame]],
+    run_kwargs: dict,
+    input_names: Union[str, List[str]] = "input",
+    output_name: str = "final_value",
+):
+    """
+    Run the grid function and build interpolators.
+
+    Args:
+        run_func: function returning (final_values, df)
+        run_kwargs: kwargs to pass to run_func
+        input_names: single input column or list of two input columns
+        output_name: column name of the output (state attribute)
+
+    Returns:
+        interp1d function (single-input) or dict of interp1d (two-input)
+    """
+    _, df = run_func(**run_kwargs)
+    return build_input_interpolator_from_df(
+        df, input_names=input_names, output_name=output_name
+    )
+
+
+# Map each env to its input(s) and output state attribute
+ENV_IO_MAPPING = {
+    Plane: {"input_names": ["power", "stick"], "state_attr": "z"},
+    Bike: {"input_names": ["power", "stick"], "state_attr": "z"},
+    Car: {"input_names": ["throttle"], "state_attr": "velocity"},
+    CSTR: {"input_names": ["T_c"], "state_attr": "T"},
+}
+
+
+def build_env_interpolator(
+    env_class, env_params, input_levels=None, second_input_levels=None, steps=10_000
+):
+    """
+    Run a grid of constant inputs on an environment and invert the result.
+
+    Args:
+        env_class: environment class, must be a key of ENV_IO_MAPPING
+        env_params: parameters passed to the environment
+        input_levels: 1D array of values of the first input (required)
+        second_input_levels: 1D array of values of the second input
+            (two-input environments)
+        steps: timesteps to run for each input
+
+    Returns:
+        interp1d function mapping a final state value to the first input;
+        for two-input environments only the rows whose second input is 0
+        are used
+
+    Raises:
+        ValueError: if input_levels is missing, or if no usable row has a
+            second input of 0 (two-input environments)
+    """
+    if input_levels is None:
+        raise ValueError("input_levels must be provided.")
+
+    mapping = ENV_IO_MAPPING[env_class]
+    input_names = mapping["input_names"]
+    state_attr = mapping["state_attr"]
+
+    env_instance = env_class()
+    _, df = run_input_grid(
+        input_levels,
+        env_instance,
+        env_params,
+        steps=steps,
+        input_name=input_names[0],
+        second_input_levels=second_input_levels,
+        second_input_name=input_names[1] if len(input_names) == 2 else None,
+        state_attr=state_attr,
+    )
+    if env_class is Plane:
+        # NOTE(review): presumably drops runs that ended on the ground - confirm
+        df = df[df["final_value"] > 0]
+
+    if len(input_names) == 1:
+        # Single input: interp1d
+        return interp1d(
+            df["final_value"].to_numpy(),
+            df[input_names[0]].to_numpy(),
+            bounds_error=False,
+            fill_value=np.nan,
+            kind="linear",
+        )
+
+    # Two inputs: only keep second_input=0 for round-trip. The match uses a
+    # tolerance because an exact float comparison can miss the zero level.
+    df0 = df[np.abs(df[input_names[1]]) < _LEVEL_TOL].sort_values("final_value")
+    if df0.empty:
+        raise ValueError(f"No usable grid row with {input_names[1]}=0.")
+    return interp1d(
+        df0["final_value"].to_numpy(),
+        df0[input_names[0]].to_numpy(),
+        bounds_error=False,
+        fill_value=np.nan,
+        kind="nearest",
+    )
+
+
+def get_interpolator(env_class, env_params, resolution: int = 100, steps: int = 10_000):
+    """
+    Build the final-state -> input interpolator of an environment.
+
+    Args:
+        env_class: environment class, must be a key of ENV_IO_MAPPING
+        env_params: parameters passed to the environment
+        resolution: number of levels of the first input
+        steps: timesteps to run for each input
+
+    Returns:
+        interp1d function mapping a final state value to the first input
+    """
+    mapping = ENV_IO_MAPPING[env_class]
+    input_names = mapping["input_names"]
+
+    # Automatically set input grids
+    if len(input_names) == 2:
+        # build_env_interpolator only uses the rows whose second input is 0,
+        # so a single zero level is enough. A linspace(-1, 1, resolution) grid
+        # has no exact zero for an even resolution and costs resolution**2 runs.
+        second_input = jnp.zeros(1)
+        if env_class is Plane:
+            first_input = jnp.linspace(0, 1.0, resolution)
+        else:
+            first_input = jnp.linspace(-1.0, 1.0, resolution)
+    else:
+        env_instance = env_class()
+        try:
+            action_space = env_instance.action_space(env_params)
+            min_val = float(action_space.low[0])
+            max_val = float(action_space.high[0])
+            if env_class is Car:
+                min_val = max(min_val, 0.0)
+        except Exception:
+            # Best-effort fallback: use normalized bounds when the bounds
+            # cannot be read from the action space
+            min_val, max_val = -1.0, 1.0
+        first_input = jnp.linspace(min_val, max_val, resolution)
+        second_input = None
+
+    # Build interpolator
+    return build_env_interpolator(
+        env_class,
+        env_params,
+        input_levels=first_input,
+        second_input_levels=second_input,
+        steps=steps,
+    )
diff --git a/src/target_gym/runners/car_runner.py b/src/target_gym/runners/car_runner.py
index d1090ae..f38d7b4 100644
--- a/src/target_gym/runners/car_runner.py
+++ b/src/target_gym/runners/car_runner.py
@@ -64,7 +64,6 @@ def step_fn(carry, _):
 
 
 def build_power_interpolator_from_df(df, stick=0.0):
-    raise NotImplementedError
     tol = 1e-6
     df_stick = df[np.abs(df["stick"] - stick) < tol]
 
diff --git a/src/target_gym/runners/utils.py b/src/target_gym/runners/utils.py
new file mode 100644
index 0000000..5abf351
--- /dev/null
+++ b/src/target_gym/runners/utils.py
@@ -0,0 +1,152 @@
+import os
+from typing import Optional, Tuple, Union
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+
+
+def run_constant_policy_final_value(
+    env,
+    params,
+    action: Union[float, Tuple[float, float]],
+    state_attr: str,
+    steps: int = 10_000,
+    key_seed: int = 0,
+):
+    """
+    Run a constant policy in a JAX environment and return the final value of a specified state attribute.
+    Works safely with JAX traced arrays.
+
+    Args:
+        env: JAX environment exposing reset_env and step_env
+        params: EnvParams, must expose max_steps_in_episode
+        action: constant action applied at every step
+        state_attr: name of the state attribute to read
+        steps: number of scanned timesteps
+        key_seed: seed of the PRNG key (the same key is used at every step)
+
+    Returns:
+        Value recorded at the step before the run reached
+        max_steps_in_episode, or at the last scanned step if that never
+        happened (or happened on the very first step)
+    """
+    key = jax.random.PRNGKey(key_seed)
+    _, init_state = env.reset_env(key, params)
+
+    def step_fn(carry, _):
+        key, state, done = carry
+        # The done flag returned by the env is not used: only the step budget
+        # (max_steps_in_episode) ends a run here.
+        _, new_state, _, _, info = env.step_env(key, state, action, params)
+        truncated = new_state.t >= params.max_steps_in_episode
+        # Keep the state frozen once a previous step has ended the run
+        state = jax.lax.cond(done, lambda _: state, lambda _: new_state, operand=None)
+        done = jnp.logical_or(done, truncated)
+        value = getattr(new_state, state_attr)
+        last_value = (
+            getattr(info["last_state"], state_attr) if "last_state" in info else value
+        )
+        return (key, state, done), (last_value, done)
+
+    _, (last_value_hist, done_hist) = jax.lax.scan(
+        step_fn, (key, init_state, False), None, length=steps
+    )
+
+    # Get index of first True in done_hist (0 as well when there is none)
+    done_idx = jnp.argmax(done_hist)
+
+    # Safely handle first step case using lax.cond
+    final_value = jax.lax.cond(
+        done_idx > 0,
+        lambda idx: last_value_hist[idx - 1],
+        lambda _: last_value_hist[-1],
+        operand=done_idx,
+    )
+
+    return final_value
+
+
+def run_input_grid(
+    input_levels: jnp.ndarray,
+    env,
+    params,
+    steps: int = 10_000,
+    input_name: str = "input",
+    second_input_levels: Optional[jnp.ndarray] = None,
+    second_input_name: Optional[str] = None,
+    state_attr: str = "velocity",
+) -> Tuple[jnp.ndarray, pd.DataFrame]:
+    """
+    Runs a grid of constant inputs on an environment.
+
+    Supports:
+    - Single-input environments: Car, CSTR
+    - Two-input environments: Plane, Bike
+
+    Args:
+        input_levels: 1D array of first input (throttle, T_c, power)
+        env: JAX environment
+        params: EnvParams
+        steps: timesteps to run
+        input_name: name for CSV column of first input
+        second_input_levels: 1D array of second input (stick), optional
+        second_input_name: CSV column name for second input, required when
+            second_input_levels is given
+        state_attr: which state attribute to track ("velocity", "T", "z", etc.)
+
+    Returns:
+        final_values: jnp array of final state_attr values
+        df: pandas DataFrame with inputs and final values
+
+    Raises:
+        ValueError: if second_input_levels is given without second_input_name
+    """
+    if second_input_levels is None:
+        # Single-input env
+        def run_one_input(u):
+            return run_constant_policy_final_value(
+                env, params, action=u, state_attr=state_attr, steps=steps, key_seed=0
+            )
+
+        final_values = jax.vmap(run_one_input)(input_levels)
+        # Hand plain NumPy arrays to pandas rather than JAX arrays
+        df = pd.DataFrame(
+            {
+                input_name: np.asarray(input_levels),
+                "final_value": np.asarray(final_values),
+            }
+        )
+
+    else:
+        if second_input_name is None:
+            raise ValueError(
+                "second_input_name is required when second_input_levels is given."
+            )
+
+        # Two-input env
+        def run_one_first_input(u):
+            return jax.vmap(
+                lambda v: run_constant_policy_final_value(
+                    env,
+                    params,
+                    action=(u, v),
+                    state_attr=state_attr,
+                    steps=steps,
+                    key_seed=0,
+                )
+            )(second_input_levels)
+
+        final_values = jax.vmap(run_one_first_input)(input_levels)
+
+        # Flatten arrays for DataFrame (row-major: first input varies slowest)
+        first_levels = np.asarray(input_levels)
+        second_levels = np.asarray(second_input_levels)
+        df = pd.DataFrame(
+            {
+                input_name: np.repeat(first_levels, len(second_levels)),
+                second_input_name: np.tile(second_levels, len(first_levels)),
+                "final_value": np.asarray(final_values).flatten(),
+            }
+        )
+    return final_values, df
diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py
new file mode 100644
index 0000000..cedc781
--- /dev/null
+++ b/tests/test_interpolation.py
@@ -0,0 +1,68 @@
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+from target_gym import CSTR, Bike, Car, CarParams, CSTRParams, Plane, PlaneParams
+from target_gym.interpolator import (
+    ENV_IO_MAPPING,
+    get_interpolator,
+)
+from target_gym.runners.utils import run_input_grid
+
+STEPS = 10_000
+RESOLUTION = 100  # small for test speed
+
+
+@pytest.mark.parametrize(
+    "env_class, env_params",
+    [
+        (Car, CarParams(max_steps_in_episode=STEPS)),
+        (CSTR, CSTRParams(max_steps_in_episode=STEPS)),
+        (Plane, PlaneParams(max_steps_in_episode=STEPS)),
+        # (Bike, BikeParams(max_steps_in_episode=STEPS)),
+    ],
+)
+def test_interpolator_round_trip(env_class, env_params):
+
+    interpolator = get_interpolator(env_class, env_params, resolution=RESOLUTION)
+    mapping = ENV_IO_MAPPING[env_class]
+    input_names = mapping["input_names"]
+    state_attr = mapping["state_attr"]
+
+    # Automatically set input grids
+    if len(input_names) == 2:
+        # Same grid as get_interpolator: the second input is held at 0
+        second_input = jnp.zeros(1)
+        if env_class == Plane:
+            first_input = jnp.linspace(0, 1.0, RESOLUTION)
+        else:
+            first_input = jnp.linspace(-1.0, 1.0, RESOLUTION)
+    else:
+        env_instance = env_class()
+        try:
+            min_val = float(env_instance.action_space(env_params).low[0])
+            max_val = float(env_instance.action_space(env_params).high[0])
+            if env_class == Car:
+                min_val = max(min_val, 0.0)
+        except Exception:
+            min_val, max_val = -1.0, 1.0
+        first_input = jnp.linspace(min_val, max_val, RESOLUTION)
+        second_input = None
+    # Run the grid to get actual outputs
+    final_values, df = run_input_grid(
+        first_input,
+        env_class(),
+        env_params,
+        steps=STEPS,
+        input_name=input_names[0],
+        second_input_levels=second_input,
+        second_input_name=input_names[1] if second_input is not None else None,
+        state_attr=state_attr,
+    )
+
+    predicted_inputs = interpolator(df["final_value"].to_numpy())
+    # Mask NaNs
+    mask = ~np.isnan(predicted_inputs)
+    np.testing.assert_allclose(
+        predicted_inputs[mask], df[input_names[0]].to_numpy()[mask], rtol=1e-2
+    )