diff --git a/README.md b/README.md index 578638fd..8025b7e4 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,7 @@ QDax currently supports the following algorithms: | [Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb) | | [DCRL-ME](https://arxiv.org/abs/2401.08632) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/dcrlme.ipynb) | | [QDPG](https://arxiv.org/abs/2006.08505) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/qdpg.ipynb) | +| [PGA-ME with Occupancy](https://doi.org/10.1145/3712256.3726337) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame_occupancy.ipynb) | | [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmame.ipynb) | | [OMG-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/omgmega.ipynb) | | [CMA-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmamega.ipynb) | diff --git a/examples/pgame_occupancy.ipynb b/examples/pgame_occupancy.ipynb new file mode 100644 index 00000000..bb82d7c0 --- /dev/null +++ b/examples/pgame_occupancy.ipynb @@ -0,0 +1,839 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "lehzCNnsOXZk" + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame_occupancy.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zV-aeDlbOXZn" + }, + "source": [ + "# Optimizing with Occupancy-CVT MAP-Elites in JAX\n", + "\n", + "This notebook shows how to use QDax to find diverse and high-performing controllers using the [occupancy measure as a generic behavioural descriptor](https://dl.acm.org/doi/10.1145/3712256.3726337).\n", + "It can be run locally or on Google Colab. We recommend using a GPU.\n", + "\n", + "This notebook uses code from the QDax notebooks on CVT MAP-Elites and Policy Gradient Assisted MAP-Elites (PGA-ME) for those elements of the process that remain unchanged. The main contribution is the setup of the occupancy measure as a behavioural descriptor that utilises a full rollout as data.
It will show:\n", + "\n", + "- the usual PGA-MAP-Elites setup\n", + "- how to create the behavioural descriptor using the double CVT-trick\n", + "- how to create the centroids in the occupancy space\n", + "- an example solving the Point-Maze and some evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5nectSbCOXZo" + }, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HI8sKYtfPntC" + }, + "outputs": [], + "source": [ + "#brax version 0.12.4 has changed the API\n", + "!pip install brax==0.12.3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "f9M-zOlhOXZq" + }, + "outputs": [], + "source": [ + "!pip install ipympl | tail -n 1\n", + "# %matplotlib widget\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "\n", + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "import functools\n", + "import time\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "import numpy as np\n", + "from numpy.random import RandomState\n", + "from sklearn.cluster import KMeans\n", + "\n", + "from qdax.core.map_elites import MAPElites\n", + "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire\n", + "import qdax.tasks.brax.v1 as environments\n", + "from qdax.tasks.brax.v1.env_creators import scoring_function_brax_envs as scoring_function\n", + "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", + "from qdax.core.neuroevolution.networks.networks import MLP\n", + "from qdax.core.emitters.mutation_operators import isoline_variation\n", + "from qdax.core.emitters.standard_emitters import MixingEmitter\n", + "from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGAMEEmitter\n", + "#from qdax.utils.plotting import plot_map_elites_results\n", + "\n", + "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", + "\n", + "from jax.flatten_util import ravel_pytree\n", + "\n", + "#from IPython.display import HTML\n", + "#from brax.v1.io import html\n", + "\n", + "\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0oZ2i5-9OXZq" + }, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "#@markdown ---\n", + "batch_size = 100 #@param {type:\"number\"}\n", + "env_name = 'pointmaze'#@param['pointmaze', 'anttrap', 'walker2d_uni']\n", + "episode_length = 200 #@param {type:\"integer\"}\n", + "num_iterations = 2000 #@param {type:\"integer\"}\n", + "seed = 42 #@param {type:\"integer\"}\n", + "policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n", + "iso_sigma = 0.005 #@param {type:\"number\"}\n", + "line_sigma = 0.05 #@param {type:\"number\"}\n", + "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", + "num_centroids_stateaction = 512 #@param {type:\"integer\"}\n", + "num_centroids_repertoire = 512 #@param {type:\"integer\"}\n", + "min_descriptor = -1.0 #@param {type:\"number\"}\n", + "max_descriptor = 1.0 #@param {type:\"number\"}\n", + "#@markdown ---" + ] + }, + { + 
"cell_type": "markdown", + "metadata": { + "id": "EFHQtwUtOXZr" + }, + "source": [ + "## The standard MAP-Elites setup\n", + "\n", + "This follows the notebook for CVT-Elites:\n", + "\n", + "Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation. We also define the shared policy, that every individual in the population will use. Once the policy is defined, all individuals are defined by their parameters, that corresponds to their genotype." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fQq-tLvaOXZr" + }, + "outputs": [], + "source": [ + "# Init environment\n", + "env = environments.create(env_name, episode_length=episode_length)\n", + "reset_fn = jax.jit(env.reset)\n", + "\n", + "# Init a random key\n", + "key = jax.random.key(seed)\n", + "\n", + "# Init policy network\n", + "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", + "policy_network = MLP(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "\n", + "# Init population of controllers\n", + "key, subkey = jax.random.split(key)\n", + "keys = jax.random.split(subkey, num=batch_size)\n", + "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nHBEKX0COXZs" + }, + "source": [ + "Now that the environment and policy has been defined, it is necessary to define a function that describes how the policy must be used to interact with the environment and to store transition data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Olgn-NXNOXZs" + }, + "outputs": [], + "source": [ + "# Define the function to play a step with the policy in the environment\n", + "def play_step_fn(\n", + " env_state,\n", + " policy_params,\n", + " key,\n", + "):\n", + " \"\"\"\n", + " Play an environment step and return the updated state and the transition.\n", + " \"\"\"\n", + " actions = policy_network.apply(policy_params, env_state.obs)\n", + " state_desc = env_state.info[\"state_descriptor\"]\n", + " next_state = env.step(env_state, actions)\n", + " transition = QDTransition(\n", + " obs=env_state.obs,\n", + " next_obs=next_state.obs,\n", + " rewards=next_state.reward,\n", + " dones=next_state.done,\n", + " actions=actions,\n", + " truncations=next_state.info[\"truncation\"],\n", + " state_desc=state_desc,\n", + " next_state_desc=next_state.info[\"state_descriptor\"],\n", + " )\n", + "\n", + " return next_state, policy_params, key, transition" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wiDpez5KOXZt" + }, + "source": [ + "The emitter is used to evolve the population at each mutation step. This follows the PGA-MAP-Elites setup." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4q3eH7LhOXZt" + }, + "outputs": [], + "source": [ + "###### PG Parameters\n", + "proportion_mutation_ga = 0.5 # @param {type:\"number\"}\n", + "# TD3 params\n", + "replay_buffer_size = 1000000 # @param {type:\"number\"}\n", + "critic_hidden_layer_size = policy_hidden_layer_sizes # @param {type:\"raw\"}\n", + "critic_learning_rate = 3e-4 # @param {type:\"number\"}\n", + "greedy_learning_rate = 3e-4 # @param {type:\"number\"}\n", + "policy_learning_rate = 1e-3 # @param {type:\"number\"}\n", + "noise_clip = 0.5 # @param {type:\"number\"}\n", + "policy_noise = 0.2 # @param {type:\"number\"}\n", + "discount = 0.99 # @param {type:\"number\"}\n", + "reward_scaling = 1.0 # @param {type:\"number\"}\n", + "transitions_batch_size = 256 # @param {type:\"number\"}\n", + "soft_tau_update = 0.005 # @param {type:\"number\"}\n", + "num_critic_training_steps = 300 # @param {type:\"number\"}\n", + "num_pg_training_steps = 100 # @param {type:\"number\"}\n", + "policy_delay = 2 # @param {type:\"number\"}\n", + "# @markdown ---\n", + "# Define the PG-emitter config\n", + "pga_emitter_config = PGAMEConfig(\n", + " env_batch_size=batch_size,\n", + " batch_size=transitions_batch_size,\n", + " proportion_mutation_ga=proportion_mutation_ga,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " critic_learning_rate=critic_learning_rate,\n", + " greedy_learning_rate=greedy_learning_rate,\n", + " policy_learning_rate=policy_learning_rate,\n", + " noise_clip=noise_clip,\n", + " policy_noise=policy_noise,\n", + " discount=discount,\n", + " reward_scaling=reward_scaling,\n", + " replay_buffer_size=replay_buffer_size,\n", + " soft_tau_update=soft_tau_update,\n", + " num_critic_training_steps=num_critic_training_steps,\n", + " num_pg_training_steps=num_pg_training_steps,\n", + " policy_delay=policy_delay,\n", + ")\n", + "# Define the emitter, reusing the iso/line sigmas from the QD definitions above\n", + "variation_fn = functools.partial(isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma)\n", + "pg_emitter = PGAMEEmitter(\n", + " config=pga_emitter_config,\n", + " policy_network=policy_network,\n", + " env=env,\n", + " variation_fn=variation_fn,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iEcNmeRqOXZt" + }, + "source": [ + "## Occupancy CVT-Elites\n", + "\n", + "This is the minimal code for the double-CVT trick that uses the occupancy measure as a behavioural descriptor. Here, we define functions (using JAX logic) that 1) create the behavioural descriptor based on a CVT of the state-action space and a second CVT of the occupancy space, and 2) create the centroids of that second CVT in the occupancy space.\n", + "\n", + "The method needs two CVTs: one standard CVT for the state-action space and another for the occupancy space. The occupancy space is not a \"box\" (i.e. [a_1, b_1] x ... x [a_n, b_n]) but the simplex of points with 0 <= x_i <= 1 such that sum(x_i) = 1.\n", + "The behavioural descriptor takes an episode, ignores all data gathered after the episode terminates, measures how often each \"bin\" of the state-action space is visited according to the first CVT, and assigns this visitation measure to a bin of the second CVT."
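+ , + "\n", + "As a small worked example with made-up numbers: suppose the first CVT has 4 state-action bins and an episode has 5 valid steps that fall into bins (0, 2, 2, 1, 2). The occupancy descriptor is then (1/5, 1/5, 3/5, 0), a point on the simplex, and the individual is stored in the repertoire cell whose occupancy-space centroid (second CVT) is closest to this vector."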
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tm7KcmYUR4T9" + }, + "outputs": [], + "source": [ + "def get_distr(neighbors, valid_steps, num_centroids):\n", + " # Mark steps taken after the episode terminated with an out-of-range index, so that bincount ignores them\n", + " idx = jnp.maximum(jnp.arange(neighbors.shape[0]), valid_steps)\n", + " neighbors = neighbors.at[idx].set(\n", + " num_centroids\n", + " )\n", + " # Count how often each \"bin\" is visited\n", + " return jnp.bincount(neighbors.flatten(), length=num_centroids).astype(float)\n", + "\n", + "# fast calculation of nearest neighbours\n", + "@functools.partial(jax.jit, static_argnames=[\"num_centroids\", \"recall_target\"])\n", + "def get_neighbors(centroids, centroids_sq, y, num_centroids, recall_target):\n", + " y_square = jnp.sum(y**2, axis=1, keepdims=True)\n", + " # squared Euclidean distances: ||c||^2 + ||y||^2 - 2 * c.y\n", + " dists = centroids_sq + y_square.T - 2 * jnp.dot(centroids, y.T)\n", + " neighbor_dists, neighbors = jax.lax.approx_min_k(\n", + " dists.T, k=1, recall_target=recall_target\n", + " )\n", + " return neighbors\n", + "\n", + "def bd_distribution_fn_withcentroids(\n", + " data: QDTransition,\n", + " mask: jnp.ndarray,\n", + " ft_centroids: jnp.ndarray,\n", + " ft_centroids_sq: jnp.ndarray,\n", + " num_centroids: int,\n", + "):\n", + " # tanh is a lazy way of making all observation data bounded in [-1,1]\n", + " # concat obs and action to get the full state-action space\n", + " inp = jnp.tanh(jnp.concat([data.obs, data.actions], axis=-1))\n", + "\n", + " # find nearest neighbours to assign each step to the appropriate bin\n", + " all_neighbors = jax.vmap(get_neighbors, in_axes=(None, None, 0, None, None))(\n", + " ft_centroids, ft_centroids_sq, inp, num_centroids, 0.95\n", + " )\n", + " # count the non-terminal steps of each episode\n", + " all_valid_steps = (1 - mask).sum(axis=1).astype(int)\n", + "\n", + " # compute the per-bin visit counts, using only the non-terminal steps\n", + " nns = jax.vmap(get_distr, in_axes=(0, 0, None))(\n", + " all_neighbors, all_valid_steps, num_centroids\n", + " )\n", + " # normalise so that each occupancy vector sums to 1\n", + " dists = nns / jnp.expand_dims(all_valid_steps, axis=1)\n", + " return dists\n", + "\n", + "\n", + "def compute_simplex_centroids(\n", + " num_descriptors,\n", + " num_init_cvt_samples,\n", + " num_centroids,\n", + " key,\n", + "):\n", + " # since all samples are x_i > 0 and sum(x_i) == 1, sample from the Dirichlet distribution\n", + " x = jax.random.dirichlet(\n", + " key, jnp.ones(num_descriptors), shape=(num_init_cvt_samples,)\n", + " )\n", + " k_means = KMeans(\n", + " init=\"k-means++\",\n", + " n_clusters=num_centroids,\n", + " n_init=1,\n", + " random_state=RandomState(42),\n", + " )\n", + " k_means.fit(x)\n", + " return jnp.asarray(k_means.cluster_centers_)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "POKCzGHTR4kB" + }, + "source": [ + "Now, create the two sets of centroids and the behavioural descriptor function."
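+ , + "\n", + "(Optional) A quick sanity check of `compute_simplex_centroids`, not needed for the run and using arbitrary small sizes, is that the returned centroids should lie on the simplex:\n", + "\n", + "```python\n", + "test_centroids = compute_simplex_centroids(\n", + "    num_descriptors=4,\n", + "    num_init_cvt_samples=1000,\n", + "    num_centroids=8,\n", + "    key=jax.random.key(0),\n", + ")\n", + "assert bool(jnp.all(test_centroids >= 0))\n", + "assert bool(jnp.allclose(test_centroids.sum(axis=1), 1.0, atol=1e-4))\n", + "```"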
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QSpzUxnuR2zN" + }, + "outputs": [], + "source": [ + "# Compute the centroids for the state-action space (1st CVT)\n", + "key, subkey = jax.random.split(key)\n", + "ft_centroids = compute_cvt_centroids(\n", + " num_descriptors=env.observation_size + env.action_size,\n", + " num_init_cvt_samples=num_init_cvt_samples,\n", + " num_centroids=num_centroids_stateaction,\n", + " minval=-1,\n", + " maxval=1,\n", + " key=subkey,\n", + ")\n", + "\n", + "# calculate this square sum once only\n", + "ft_centroids_sq = jnp.sum(ft_centroids**2, axis=1, keepdims=True)\n", + "bd_distribution = functools.partial(\n", + " bd_distribution_fn_withcentroids,\n", + " ft_centroids=ft_centroids,\n", + " ft_centroids_sq=ft_centroids_sq,\n", + " num_centroids=num_centroids_stateaction,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Sq0QgfiQW0kG" + }, + "outputs": [], + "source": [ + "# Compute the centroids for the occupancy measure space (2nd CVT)\n", + "key, subkey = jax.random.split(key)\n", + "elites_centroids = compute_simplex_centroids(\n", + " num_descriptors=num_centroids_stateaction,\n", + " num_init_cvt_samples=num_init_cvt_samples,\n", + " num_centroids=num_centroids_repertoire,\n", + " key=subkey,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5ynAoTD_OXZs" + }, + "source": [ + "Now that the behavioural descriptor function `bd_distribution` is defined, it can be plugged into the scoring function.\n", + "\n", + "The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zWGNKySGOXZt" + }, + "outputs": [], + "source": [ + "# Prepare the scoring function\n", + "scoring_fn = functools.partial(\n", + " scoring_function,\n", + " episode_length=episode_length,\n", + " play_reset_fn=reset_fn,\n", + " play_step_fn=play_step_fn,\n", + " descriptor_extractor=bd_distribution,\n", + ")\n", + "\n", + "# Get the minimum reward value to make sure the QD score is positive\n", + "reward_offset = environments.reward_offset[env_name]\n", + "\n", + "# Define a metrics function\n", + "metrics_function = functools.partial(\n", + " default_qd_metrics,\n", + " qd_offset=reward_offset * episode_length,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l3I_sEx-HHmn" + }, + "source": [ + "Now the standard MAP-Elites logic can be instantiated."
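+ , + "\n", + "Once the next cell has run, the repertoire is a standard CVT repertoire whose cells are the simplex centroids. A rough shape sketch (assuming the sizes defined above):\n", + "\n", + "```python\n", + "# repertoire.centroids.shape   == (num_centroids_repertoire, num_centroids_stateaction)\n", + "# repertoire.descriptors.shape == (num_centroids_repertoire, num_centroids_stateaction)\n", + "```"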
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "otm3MoUiOXZu" + }, + "outputs": [], + "source": [ + "# Instantiate MAP-Elites\n", + "map_elites = MAPElites(\n", + " scoring_function=scoring_fn,\n", + " emitter=pg_emitter,\n", + " metrics_function=metrics_function,\n", + ")\n", + "\n", + "# Compute initial repertoire and emitter state\n", + "key, subkey = jax.random.split(key)\n", + "repertoire, emitter_state, init_metrics = map_elites.init(init_variables, elites_centroids, subkey)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pxEWzQ_UOXZu" + }, + "source": [ + "## Launch MAP-Elites iterations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yFlGDrzZOXZu" + }, + "outputs": [], + "source": [ + "log_period = 10\n", + "num_loops = num_iterations // log_period\n", + "\n", + "# Initialize metrics\n", + "metrics = {key: jnp.array([]) for key in [\"iteration\", \"qd_score\", \"coverage\", \"max_fitness\", \"time\"]}\n", + "\n", + "# Set up init metrics\n", + "init_metrics = jax.tree.map(lambda x: jnp.array([x]) if x.shape == () else x, init_metrics)\n", + "init_metrics[\"iteration\"] = jnp.array([0], dtype=jnp.int32)\n", + "init_metrics[\"time\"] = jnp.array([0.0]) # No time recorded for initialization\n", + "\n", + "# Convert init_metrics to match the metrics dictionary structure\n", + "metrics = jax.tree.map(lambda metric, init_metric: jnp.concatenate([metric, init_metric], axis=0), metrics, init_metrics)\n", + "\n", + "# Initialize CSV logger\n", + "csv_logger = CSVLogger(\n", + " \"mapelites-logs.csv\",\n", + " header=list(metrics.keys())\n", + ")\n", + "\n", + "# Log initial metrics\n", + "csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))\n", + "\n", + "# Main loop\n", + "map_elites_scan_update = map_elites.scan_update\n", + "for i in range(num_loops):\n", + " start_time = time.time()\n", + " (\n", + " repertoire,\n", + " emitter_state,\n", + " key,\n", + " ), current_metrics = jax.lax.scan(\n", + " map_elites_scan_update,\n", + " (repertoire, emitter_state, key),\n", + " (),\n", + " length=log_period,\n", + " )\n", + " timelapse = time.time() - start_time\n", + "\n", + " # Metrics\n", + " current_metrics[\"iteration\"] = jnp.arange(1+log_period*i, 1+log_period*(i+1), dtype=jnp.int32)\n", + " current_metrics[\"time\"] = jnp.repeat(timelapse, log_period)\n", + " metrics = jax.tree.map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics, current_metrics)\n", + "\n", + " # Log\n", + " csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KvKKZctyOXZu" + }, + "outputs": [], + "source": [ + "#@title Visualization\n", + "\n", + "# Create the x-axis array\n", + "env_steps = metrics[\"iteration\"]\n", + "\n", + "%matplotlib inline\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Using only part of plot_map_elites_results from QDax, as the repertoire is not 2D ...\n", + "x_label = \"Environment steps\"\n", + "# Customize matplotlib params\n", + "font_size = 16\n", + "params = {\n", + " \"axes.labelsize\": font_size,\n", + " \"axes.titlesize\": font_size,\n", + " \"legend.fontsize\": font_size,\n", + " \"xtick.labelsize\": font_size,\n", + " \"ytick.labelsize\": font_size,\n", + " \"text.usetex\": False,\n", + " \"axes.titlepad\": 10,\n", + "}\n", + "\n", + "mpl.rcParams.update(params)\n", + "\n", + "# Visualize the training evolution 
and final repertoire\n", + "fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(40, 9))\n", + "\n", + "# env_steps = jnp.arange(num_iterations) * episode_length * batch_size\n", + "\n", + "axes[0].plot(env_steps, metrics[\"coverage\"])\n", + "axes[0].set_xlabel(x_label)\n", + "axes[0].set_ylabel(\"Coverage in %\")\n", + "axes[0].set_title(\"Coverage evolution during training\")\n", + "axes[0].set_aspect(0.95 / axes[0].get_data_ratio(), adjustable=\"box\")\n", + "\n", + "axes[1].plot(env_steps, metrics[\"max_fitness\"])\n", + "axes[1].set_xlabel(x_label)\n", + "axes[1].set_ylabel(\"Maximum fitness\")\n", + "axes[1].set_title(\"Maximum fitness evolution during training\")\n", + "axes[1].set_aspect(0.95 / axes[1].get_data_ratio(), adjustable=\"box\")\n", + "\n", + "axes[2].plot(env_steps, metrics[\"qd_score\"])\n", + "axes[2].set_xlabel(x_label)\n", + "axes[2].set_ylabel(\"QD Score\")\n", + "axes[2].set_title(\"QD Score evolution during training\")\n", + "axes[2].set_aspect(0.95 / axes[2].get_data_ratio(), adjustable=\"box\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Jy6pTtjIOXZu" + }, + "outputs": [], + "source": [ + "best_idx = jnp.argmax(repertoire.fitnesses)\n", + "best_fitness = jnp.max(repertoire.fitnesses)\n", + "best_descriptor = repertoire.descriptors[best_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7q-r05UYOXZv" + }, + "outputs": [], + "source": [ + "print(\n", + " f\"Best fitness in the repertoire: {best_fitness:.2f}\\n\",\n", + " f\"Index in the repertoire of this individual: {best_idx}\\n\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qmhxD-HuOXZv" + }, + "outputs": [], + "source": [ + "# select the parameters of the best individual\n", + "my_params = jax.tree.map(\n", + " lambda x: x[best_idx],\n", + " repertoire.genotypes\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JqcTDMcrOXZv" + }, + "source": [ + "## Play some steps in the environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rmGUxXBNOXZv" + }, + "outputs": [], + "source": [ + "vis_env = environments.create(env_name, episode_length=episode_length, auto_reset=False)\n", + "jit_env_reset = jax.jit(vis_env.reset)\n", + "jit_env_step = jax.jit(vis_env.step)\n", + "jit_inference_fn = jax.jit(policy_network.apply)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PlkfDrxAOXZv" + }, + "outputs": [], + "source": [ + "rollout = []\n", + "key, subkey = jax.random.split(key)\n", + "state = jit_env_reset(rng=subkey)\n", + "rollout.append(state)\n", + "\n", + "while not state.done:\n", + " action = jit_inference_fn(my_params, state.obs)\n", + " state = jit_env_step(state, action)\n", + " rollout.append(state)\n", + "\n", + "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "njogGypZOULf" + }, + "outputs": [], + "source": [ + "episode = [ts.obs for ts in rollout]\n", + "episode = np.stack(episode)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d1RFYwcyN9Oy" + }, + "source": [ + "At this point there is no built-in way to properly visualise Point-Maze. Here, a simple attempt to show an episode." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "olt-s84pOXZv" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "def plot_trajectory(data):\n", + " # Hardcoded boundaries and wall positions\n", + " x_min = -1\n", + " x_max = 1\n", + " y_min = -1\n", + " y_max = 1\n", + " upper_wall_height_offset = 0.2\n", + " lower_wall_height_offset = -0.5\n", + " wallheight = 0.01\n", + " wallwidth = (x_max - x_min) * 0.75\n", + "\n", + " h_zone_width = 0.05\n", + " zone_width_offset = x_min + 0.5\n", + " zone_height_offset = y_max + -0.2\n", + "\n", + " # Upper wall coordinates\n", + " upper_wall_x_1 = [x_min, x_min + wallwidth]\n", + " upper_wall_y_1 = [upper_wall_height_offset, upper_wall_height_offset]\n", + " upper_wall_x_2 = [x_min, x_min + wallwidth]\n", + " upper_wall_y_2 = [upper_wall_height_offset + wallheight, upper_wall_height_offset + wallheight]\n", + "\n", + " # Lower wall coordinates\n", + " lower_wall_x_1 = [x_max - wallwidth, x_max]\n", + " lower_wall_y_1 = [lower_wall_height_offset, lower_wall_height_offset]\n", + " lower_wall_x_2 = [x_max - wallwidth, x_max]\n", + " lower_wall_y_2 = [lower_wall_height_offset + wallheight, lower_wall_height_offset + wallheight]\n", + "\n", + " # Zone\n", + " zone_x_1 = [zone_width_offset - h_zone_width, zone_width_offset + h_zone_width]\n", + " zone_y_1 = [zone_height_offset - h_zone_width, zone_height_offset- h_zone_width]\n", + " zone_x_2 = [zone_width_offset - h_zone_width, zone_width_offset + h_zone_width]\n", + " zone_y_2 = [zone_height_offset + h_zone_width, zone_height_offset + h_zone_width]\n", + " zone_x_3 = [zone_width_offset - h_zone_width, zone_width_offset - h_zone_width]\n", + " zone_y_3 = [zone_height_offset - h_zone_width, zone_height_offset + h_zone_width]\n", + " zone_x_4 = [zone_width_offset + h_zone_width, zone_width_offset + h_zone_width]\n", + " zone_y_4 = [zone_height_offset - h_zone_width, zone_height_offset + h_zone_width]\n", + "\n", + " fig = plt.figure(figsize=(6, 6))\n", + "\n", + " # Plotting the trajectory\n", + " plt.plot(data[:, 0], data[:, 1], marker='o', label='Trajectory')\n", + "\n", + " # Plotting the walls\n", + " plt.plot(upper_wall_x_1, upper_wall_y_1, color='red', label='Wall')\n", + " plt.plot(upper_wall_x_2, upper_wall_y_2, color='red')\n", + " plt.plot(lower_wall_x_1, lower_wall_y_1, color='red')\n", + " plt.plot(lower_wall_x_2, lower_wall_y_2, color='red')\n", + "\n", + " plt.plot(zone_x_1, zone_y_1, color='green', label = \"Goal\")\n", + " plt.plot(zone_x_2, zone_y_2, color='green')\n", + " plt.plot(zone_x_3, zone_y_3, color='green')\n", + " plt.plot(zone_x_4, zone_y_4, color='green')\n", + "\n", + " # Plotting the boundaries\n", + " plt.axhline(y=y_min, color='k', linestyle='--', label='Boundary')\n", + " plt.axhline(y=y_max, color='k', linestyle='--')\n", + " plt.axvline(x=x_min, color='k', linestyle='--')\n", + " plt.axvline(x=x_max, color='k', linestyle='--')\n", + "\n", + " plt.title('Particle Trajectory in 2D Maze')\n", + " plt.xlabel('X Coordinate')\n", + " plt.ylabel('Y Coordinate')\n", + " plt.grid()\n", + " plt.axis('equal')\n", + " plt.legend()\n", + " plt.show()\n", + "\n", + "# Example usage:\n", + "plot_trajectory(episode)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + 
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}