Skip to content

Commit f76e84f

Browse files
committed
Merge branch 'main' into develop
2 parents 56ad87e + f06dda3 commit f76e84f

File tree

6 files changed

+27
-21
lines changed

6 files changed

+27
-21
lines changed

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,14 @@ centroids = compute_euclidean_centroids(
114114
key, subkey = jax.random.split(key)
115115
repertoire, emitter_state, metrics = map_elites.init(init_variables, centroids, subkey)
116116

117+
# JIT-compile the update function for faster iterations
118+
update_fn = jax.jit(map_elites.update)
119+
117120
# Run MAP-Elites loop
118121
for i in range(num_iterations):
119122
key, subkey = jax.random.split(key)
120-
(repertoire, emitter_state, metrics,) = map_elites.update(
123+
124+
(repertoire, emitter_state, metrics,) = update_fn(
121125
repertoire,
122126
emitter_state,
123127
subkey,
@@ -146,7 +150,7 @@ QDax currently supports the following algorithms:
146150
| [Multi-Objective MAP-Elites (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb) |
147151
| [MAP-Elites Evolution Strategies (MEES)](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mees.ipynb) |
148152
| [MAP-Elites PBT (ME-PBT)](https://openreview.net/forum?id=CBfYffLqWqb) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/me_sac_pbt.ipynb) |
149-
| [MAP-Elites Low-Spread (ME-LS)](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/me_ls.ipynb) |
153+
| [MAP-Elites Low-Spread (ME-LS)](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mels.ipynb) |
150154

151155

152156
## QDax baseline algorithms
@@ -200,6 +204,9 @@ QDax was developed and is maintained by the [Adaptive & Intelligent Robotics Lab
200204
<a href="https://github.com/maxiallard" title="Maxime Allard"><img src="https://github.com/maxiallard.png" height="auto" width="50" style="border-radius:50%"></a>
201205
<a href="https://github.com/Lookatator" title="Luca Grillotti"><img src="https://github.com/Lookatator.png" height="auto" width="50" style="border-radius:50%"></a>
202206
<a href="https://github.com/manon-but-yes" title="Manon Flageat"><img src="https://github.com/manon-but-yes.png" height="auto" width="50" style="border-radius:50%"></a>
207+
<a href="https://github.com/maxencefaldor" title="Maxence Faldor"><img src="https://github.com/maxencefaldor.png" height="auto" width="50" style="border-radius:50%"></a>
208+
<a href="https://github.com/hannah-jan" title="Hannah Janmohamed"><img src="https://github.com/hannah-jan.png" height="auto" width="50" style="border-radius:50%"></a>
209+
<a href="https://github.com/LisaCoiffard" title="Lisa Coiffard"><img src="https://github.com/LisaCoiffard.png" height="auto" width="50" style="border-radius:50%"></a>
203210
<a href="https://github.com/Aneoshun" title="Antoine Cully"><img src="https://github.com/Aneoshun.png" height="auto" width="50" style="border-radius:50%"></a>
204211
<a href="https://github.com/felixchalumeau" title="Felix Chalumeau"><img src="https://github.com/felixchalumeau.png" height="auto" width="50" style="border-radius:50%"></a>
205212
<a href="https://github.com/ranzenTom" title="Thomas Pierrot"><img src="https://github.com/ranzenTom.png" height="auto" width="50" style="border-radius:50%"></a>

examples/aurora.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@
370370
" model, subkey, (1, *observations_dims)\n",
371371
")\n",
372372
"\n",
373-
"print(jax.tree_map(lambda x: x.shape, model_params))\n",
373+
"print(jax.tree.map(lambda x: x.shape, model_params))\n",
374374
"\n",
375375
"# Define the encoder function\n",
376376
"encoder_fn = jax.jit(\n",

examples/me_sac_pbt.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,10 @@
236236
"\n",
237237
"def scoring_function(genotypes, key):\n",
238238
" population_size = jax.tree.leaves(genotypes)[0].shape[0]\n",
239-
" first_states = jax.tree_map(\n",
239+
" first_states = jax.tree.map(\n",
240240
" lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states\n",
241241
" )\n",
242-
" first_states = jax.tree_map(\n",
242+
" first_states = jax.tree.map(\n",
243243
" lambda x: jnp.repeat(x, population_size, axis=0), first_states\n",
244244
" )\n",
245245
" population_returns, population_descriptors, _, _ = eval_policy(genotypes, first_states)\n",
@@ -378,7 +378,7 @@
378378
" repertoire, emitter_state, metrics = update_fn(\n",
379379
" repertoire, emitter_state, keys\n",
380380
" )\n",
381-
" metrics_cpu = jax.tree_map(\n",
381+
" metrics_cpu = jax.tree.map(\n",
382382
" lambda x: jax.device_put(x, jax.devices(\"cpu\")[0])[0], metrics\n",
383383
" )\n",
384384
" timelapse = time.time() - start_time\n",
@@ -401,7 +401,7 @@
401401
"outputs": [],
402402
"source": [
403403
"# Create the performance evolution plots and visualize final grid\n",
404-
"repertoire_cpu = jax.tree_map(\n",
404+
"repertoire_cpu = jax.tree.map(\n",
405405
" lambda x: jax.device_put(x, jax.devices(\"cpu\")[0])[0], repertoire\n",
406406
")\n",
407407
"num_loops_with_init = num_loops + 1\n",
@@ -510,7 +510,7 @@
510510
"key, subkey = jax.random.split(key)\n",
511511
"env_state = jax.jit(env.reset)(rng=subkey)\n",
512512
"\n",
513-
"training_state, env_state = jax.tree_map(\n",
513+
"training_state, env_state = jax.tree.map(\n",
514514
" lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)\n",
515515
")\n",
516516
"\n",
@@ -529,7 +529,7 @@
529529
"outputs": [],
530530
"source": [
531531
"rollout = [\n",
532-
" jax.tree_map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n",
532+
" jax.tree.map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n",
533533
" for env_state in rollout\n",
534534
"]"
535535
]

examples/me_td3_pbt.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,10 @@
238238
"\n",
239239
"def scoring_function(genotypes, key):\n",
240240
" population_size = jax.tree_leaves(genotypes)[0].shape[0]\n",
241-
" first_states = jax.tree_map(\n",
241+
" first_states = jax.tree.map(\n",
242242
" lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states\n",
243243
" )\n",
244-
" first_states = jax.tree_map(\n",
244+
" first_states = jax.tree.map(\n",
245245
" lambda x: jnp.repeat(x, population_size, axis=0), first_states\n",
246246
" )\n",
247247
" population_returns, population_descriptors, _, _ = eval_policy(genotypes, first_states)\n",
@@ -379,7 +379,7 @@
379379
" repertoire, emitter_state, metrics = update_fn(\n",
380380
" repertoire, emitter_state, keys\n",
381381
" )\n",
382-
" metrics_cpu = jax.tree_map(lambda x: jax.device_get(x)[0], metrics)\n",
382+
" metrics_cpu = jax.tree.map(lambda x: jax.device_get(x)[0], metrics)\n",
383383
" timelapse = time.time() - start_time\n",
384384
"\n",
385385
" # log metrics\n",
@@ -391,7 +391,7 @@
391391
" all_metrics[k] = v\n",
392392
"\n",
393393
" if i % save_repertoire_freq == 0:\n",
394-
" repertoires.append(jax.tree_map(lambda x: jax.device_get(x)[0], repertoire))"
394+
" repertoires.append(jax.tree.map(lambda x: jax.device_get(x)[0], repertoire))"
395395
]
396396
},
397397
{

qdax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.4.1"
1+
__version__ = "0.5.0"

qdax/tasks/brax/v2/wrappers/eval_metrics_wrapper.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@ def reset(self, rng: jnp.ndarray) -> State:
2222
reset_state = self.env.reset(rng)
2323
reset_state.metrics["reward"] = reset_state.reward
2424
eval_metrics = CompletedEvalMetrics(
25-
current_episode_metrics=jax.tree_util.tree_map(
26-
jnp.zeros_like, reset_state.metrics
27-
),
28-
completed_episodes_metrics=jax.tree_util.tree_map(
25+
current_episode_metrics=jax.tree.map(jnp.zeros_like, reset_state.metrics),
26+
completed_episodes_metrics=jax.tree.map(
2927
lambda x: jnp.zeros_like(jnp.sum(x)), reset_state.metrics
3028
),
3129
completed_episodes=jnp.zeros(()),
@@ -46,16 +44,17 @@ def step(self, state: State, action: jnp.ndarray) -> State:
4644
completed_episodes_steps = state_metrics.completed_episodes_steps + jnp.sum(
4745
nstate.info["steps"] * nstate.done
4846
)
49-
current_episode_metrics = jax.tree_util.tree_map(
47+
48+
current_episode_metrics = jax.tree.map(
5049
lambda a, b: a + b, state_metrics.current_episode_metrics, nstate.metrics
5150
)
5251
completed_episodes = state_metrics.completed_episodes + jnp.sum(nstate.done)
53-
completed_episodes_metrics = jax.tree_util.tree_map(
52+
completed_episodes_metrics = jax.tree.map(
5453
lambda a, b: a + jnp.sum(b * nstate.done),
5554
state_metrics.completed_episodes_metrics,
5655
current_episode_metrics,
5756
)
58-
current_episode_metrics = jax.tree_util.tree_map(
57+
current_episode_metrics = jax.tree.map(
5958
lambda a, b: a * (1 - nstate.done) + b * nstate.done,
6059
current_episode_metrics,
6160
nstate.metrics,

0 commit comments

Comments
 (0)