Skip to content

Commit 9fc7736

Browse files
committed
Remove brax v1
1 parent 0c50bba commit 9fc7736

File tree

141 files changed

+726
-3737
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

141 files changed

+726
-3737
lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
- name: Set up Python
1717
uses: actions/setup-python@v5
1818
with:
19-
python-version: '3.11'
19+
python-version: "3.11"
2020

2121
- name: Install system dependencies
2222
run: |

examples/aurora.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@
7979
"\n",
8080
"from qdax.core.aurora import AURORA\n",
8181
"from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n",
82-
"import qdax.tasks.brax.v1 as environments\n",
83-
"from qdax.tasks.brax.v1.env_creators import (\n",
82+
"import qdax.tasks.brax as environments\n",
83+
"from qdax.tasks.brax.env_creators import (\n",
8484
" create_default_brax_task_components,\n",
8585
" get_aurora_scoring_fn,\n",
8686
")\n",

examples/cmaes.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,12 @@
129129
"metadata": {},
130130
"outputs": [],
131131
"source": [
132-
"def rastrigin_scoring(x: jnp.ndarray):\n",
132+
"def rastrigin_scoring(x: jax.Array):\n",
133133
" first_term = 10 * x.shape[-1]\n",
134134
" second_term = jnp.sum((x + minval * 0.4) ** 2 - 10 * jnp.cos(2 * jnp.pi * (x + minval * 0.4)))\n",
135135
" return -(first_term + second_term)\n",
136136
"\n",
137-
"def sphere_scoring(x: jnp.ndarray):\n",
137+
"def sphere_scoring(x: jax.Array):\n",
138138
" return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1)\n",
139139
"\n",
140140
"if optim_problem == \"sphere\":\n",

examples/cmame.ipynb

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,12 @@
132132
"metadata": {},
133133
"outputs": [],
134134
"source": [
135-
"def rastrigin_scoring(x: jnp.ndarray):\n",
135+
"def rastrigin_scoring(x: jax.Array):\n",
136136
" first_term = 10 * x.shape[-1]\n",
137137
" second_term = jnp.sum((x + minval * 0.4) ** 2 - 10 * jnp.cos(2 * jnp.pi * (x + minval * 0.4)))\n",
138138
" return -(first_term + second_term)\n",
139139
"\n",
140-
"def sphere_scoring(x: jnp.ndarray):\n",
140+
"def sphere_scoring(x: jax.Array):\n",
141141
" return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1)\n",
142142
"\n",
143143
"if optim_problem == \"sphere\":\n",
@@ -147,21 +147,21 @@
147147
"else:\n",
148148
" raise Exception(\"Invalid opt function name given\")\n",
149149
"\n",
150-
"def clip(x: jnp.ndarray):\n",
150+
"def clip(x: jax.Array):\n",
151151
" in_bound = (x <= maxval) * (x >= minval)\n",
152152
" return jnp.where(\n",
153153
" in_bound,\n",
154154
" x,\n",
155155
" (maxval / x)\n",
156156
" )\n",
157157
"\n",
158-
"def _descriptor_1(x: jnp.ndarray):\n",
158+
"def _descriptor_1(x: jax.Array):\n",
159159
" return jnp.sum(clip(x[:x.shape[-1]//2]))\n",
160160
"\n",
161-
"def _descriptor_2(x: jnp.ndarray):\n",
161+
"def _descriptor_2(x: jax.Array):\n",
162162
" return jnp.sum(clip(x[x.shape[-1]//2:]))\n",
163163
"\n",
164-
"def _descriptors(x: jnp.ndarray):\n",
164+
"def _descriptors(x: jax.Array):\n",
165165
" return jnp.array([_descriptor_1(x), _descriptor_2(x)])"
166166
]
167167
},
@@ -198,7 +198,7 @@
198198
"\n",
199199
"num_centroids = math.prod(grid_shape)\n",
200200
"\n",
201-
"def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]:\n",
201+
"def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jax.Array]:\n",
202202
"\n",
203203
" # get metrics\n",
204204
" grid_empty = repertoire.fitnesses == -jnp.inf\n",

examples/cmamega.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,19 @@
124124
"metadata": {},
125125
"outputs": [],
126126
"source": [
127-
"def rastrigin_scoring(x: jnp.ndarray):\n",
127+
"def rastrigin_scoring(x: jax.Array\n",
128128
" return -(10 * x.shape[-1] + jnp.sum((x+minval*0.4)**2 - 10 * jnp.cos(2 * jnp.pi * (x+minval*0.4))))\n",
129129
"\n",
130-
"def clip(x: jnp.ndarray):\n",
130+
"def clip(x: jax.Array\n",
131131
" return x*(x<=maxval)*(x>=+minval) + maxval/x*((x>maxval)+(x<+minval))\n",
132132
"\n",
133-
"def _rastrigin_descriptor_1(x: jnp.ndarray):\n",
133+
"def _rastrigin_descriptor_1(x: jax.Array\n",
134134
" return jnp.mean(clip(x[:x.shape[-1]//2]))\n",
135135
"\n",
136-
"def _rastrigin_descriptor_2(x: jnp.ndarray):\n",
136+
"def _rastrigin_descriptor_2(x: jax.Array\n",
137137
" return jnp.mean(clip(x[x.shape[-1]//2:]))\n",
138138
"\n",
139-
"def rastrigin_descriptors(x: jnp.ndarray):\n",
139+
"def rastrigin_descriptors(x: jax.Array\n",
140140
" return jnp.array([_rastrigin_descriptor_1(x), _rastrigin_descriptor_2(x)])\n",
141141
"\n",
142142
"rastrigin_grad_scores = jax.grad(rastrigin_scoring)"
@@ -191,7 +191,7 @@
191191
"best_objective = rastrigin_scoring(jnp.ones(num_dimensions) * 5.12 * 0.4)\n",
192192
"\n",
193193
"\n",
194-
"def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]:\n",
194+
"def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jax.Array\n",
195195
"\n",
196196
" # get metrics\n",
197197
" grid_empty = repertoire.fitnesses == -jnp.inf\n",

examples/dads.ipynb

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,14 @@
7575
"import jax\n",
7676
"import jax.numpy as jnp\n",
7777
"\n",
78-
"import qdax.tasks.brax.v1 as environments\n",
78+
"import qdax.tasks.brax as environments\n",
7979
"from qdax.baselines.dads import DADS, DadsConfig, DadsTrainingState\n",
8080
"from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer\n",
8181
"from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer\n",
8282
"\n",
8383
"from qdax.utils.plotting import plot_skills_trajectory\n",
8484
"\n",
85-
"from IPython.display import HTML\n",
86-
"from brax.v1.io import html"
85+
"from IPython.display import HTML"
8786
]
8887
},
8988
{
@@ -94,7 +93,7 @@
9493
"\n",
9594
"Most hyperparameters are similar to those introduced in [SAC paper](https://arxiv.org/abs/1801.01290), [DIAYN paper](https://arxiv.org/abs/1802.06070) and [DADS paper](https://arxiv.org/abs/1907.01657).\n",
9695
"\n",
97-
"The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and dynamics. In DADS, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. (All the `_uni`, `_omni` do, same for `anttrap`, `antmaze` and `pointmaze`.) In the future, we will add an option to use a prior function directly on the full state."
96+
"The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and dynamics. In DADS, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. In the future, we will add an option to use a prior function directly on the full state."
9897
]
9998
},
10099
{
@@ -385,7 +384,7 @@
385384
"source": [
386385
"## Plot the trajectories of the skills at the end of the training\n",
387386
"\n",
388-
"This only works when the state descriptor considered is two-dimensional, and as a real interest only when this state descriptor is the x/y position. Hence, on all \"omni\" tasks, on pointmaze, anttrap and antmaze."
387+
"This only works when the state descriptor considered is two-dimensional, and as a real interest only when this state descriptor is the x/y position."
389388
]
390389
},
391390
{
@@ -419,18 +418,7 @@
419418
"cell_type": "markdown",
420419
"metadata": {},
421420
"source": [
422-
"# Visualize the skills in the physical simulation\n",
423-
"\n",
424-
"WARNING: this does not work with \"pointmaze\""
425-
]
426-
},
427-
{
428-
"cell_type": "code",
429-
"execution_count": null,
430-
"metadata": {},
431-
"outputs": [],
432-
"source": [
433-
"assert env_name != \"pointmaze\", \"No visualisation available for pointmaze at the moment\""
421+
"# Visualize the skills in the physical simulation"
434422
]
435423
},
436424
{

examples/dcrlme.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,17 @@
8080
"import jax\n",
8181
"import jax.numpy as jnp\n",
8282
"\n",
83-
"import qdax.tasks.brax.v1 as environments\n",
83+
"import qdax.tasks.brax as environments\n",
8484
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n",
8585
"from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter\n",
8686
"from qdax.core.emitters.mutation_operators import isoline_variation\n",
8787
"from qdax.core.map_elites import MAPElites\n",
8888
"from qdax.core.neuroevolution.buffers.buffer import DCRLTransition\n",
8989
"from qdax.core.neuroevolution.networks.networks import MLP, MLPDC\n",
9090
"from qdax.custom_types import EnvState, Params, RNGKey\n",
91-
"from qdax.tasks.brax.v1 import descriptor_extractor\n",
92-
"from qdax.tasks.brax.v1.wrappers.reward_wrappers import OffsetRewardWrapper, ClipRewardWrapper\n",
93-
"from qdax.tasks.brax.v1.env_creators import scoring_function_brax_envs\n",
91+
"from qdax.tasks.brax import descriptor_extractor\n",
92+
"from qdax.tasks.brax.wrappers.reward_wrappers import OffsetRewardWrapper, ClipRewardWrapper\n",
93+
"from qdax.tasks.brax.env_creators import scoring_function_brax_envs\n",
9494
"from qdax.utils.plotting import plot_map_elites_results\n",
9595
"\n",
9696
"from qdax.utils.metrics import CSVLogger, default_qd_metrics"
@@ -436,7 +436,7 @@
436436
],
437437
"metadata": {
438438
"kernelspec": {
439-
"display_name": "venv",
439+
"display_name": ".venv",
440440
"language": "python",
441441
"name": "python3"
442442
},
@@ -450,7 +450,7 @@
450450
"name": "python",
451451
"nbconvert_exporter": "python",
452452
"pygments_lexer": "ipython3",
453-
"version": "3.10.12"
453+
"version": "3.13.9"
454454
}
455455
},
456456
"nbformat": 4,

examples/diayn.ipynb

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,14 @@
7575
"import jax\n",
7676
"import jax.numpy as jnp\n",
7777
"\n",
78-
"import qdax.tasks.brax.v1 as environments\n",
78+
"import qdax.tasks.brax as environments\n",
7979
"from qdax.baselines.diayn import DIAYN, DiaynConfig, DiaynTrainingState\n",
8080
"from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer\n",
8181
"from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer\n",
8282
"\n",
8383
"from qdax.utils.plotting import plot_skills_trajectory\n",
8484
"\n",
85-
"from IPython.display import HTML\n",
86-
"from brax.v1.io import html"
85+
"from IPython.display import HTML"
8786
]
8887
},
8988
{
@@ -94,7 +93,7 @@
9493
"\n",
9594
"Most hyperparameters are similar to those introduced in [SAC paper](https://arxiv.org/abs/1801.01290) and [DIAYN paper](https://arxiv.org/abs/1802.06070).\n",
9695
"\n",
97-
"The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and discrimination. In DIAYN, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. (All the `_uni`, `_omni` do, same for `anttrap`, `antmaze` and `pointmaze`.) In the future, we will add an option to use a prior function directly on the full state."
96+
"The parameter `descriptor_full_state` is less straightforward, it concerns the information used for diversity seeking and discrimination. In DIAYN, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. Actually, priors are often used in experiments, for instance, focusing on the x/y position rather than the full position. When `descriptor_full_state` is set to True, it uses the full state, when it is set to False, it uses the 'state descriptor' retrieved by the environment. Hence, it is required that the environment has one. In the future, we will add an option to use a prior function directly on the full state."
9897
]
9998
},
10099
{
@@ -374,7 +373,7 @@
374373
"source": [
375374
"## Plot the trajectories of the skills at the end of the training\n",
376375
"\n",
377-
"This only works when the state descriptor considered is two-dimensional, and as a real interest only when this state descriptor is the x/y position. Hence, on all \"omni\" tasks, on pointmaze, anttrap and antmaze."
376+
"This only works when the state descriptor considered is two-dimensional, and as a real interest only when this state descriptor is the x/y position."
378377
]
379378
},
380379
{
@@ -408,18 +407,7 @@
408407
"cell_type": "markdown",
409408
"metadata": {},
410409
"source": [
411-
"# Visualize the skills in the physical simulation\n",
412-
"\n",
413-
"WARNING: this does not work with \"pointmaze\""
414-
]
415-
},
416-
{
417-
"cell_type": "code",
418-
"execution_count": null,
419-
"metadata": {},
420-
"outputs": [],
421-
"source": [
422-
"assert env_name != \"pointmaze\", \"No visualisation available for pointmaze at the moment\""
410+
"# Visualize the skills in the physical simulation"
423411
]
424412
},
425413
{

examples/distributed_mapelites.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@
8181
"\n",
8282
"from qdax.core.distributed_map_elites import DistributedMAPElites\n",
8383
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n",
84-
"import qdax.tasks.brax.v1 as environments\n",
85-
"from qdax.tasks.brax.v1.env_creators import scoring_function_brax_envs as scoring_function\n",
84+
"import qdax.tasks.brax as environments\n",
85+
"from qdax.tasks.brax.env_creators import scoring_function_brax_envs as scoring_function\n",
8686
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
8787
"from qdax.core.neuroevolution.networks.networks import MLP\n",
8888
"from qdax.core.emitters.mutation_operators import isoline_variation\n",

examples/jumanji_snake.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@
284284
"outputs": [],
285285
"source": [
286286
"# Prepare the scoring function\n",
287-
"def descriptor_extraction(data: QDTransition, mask: jnp.ndarray, linear_projection: jnp.ndarray) -> Descriptor:\n",
287+
"def descriptor_extraction(data: QDTransition, mask: jax.Arraylinear_projection: jajax.Array Descriptor:\n",
288288
" \"\"\"Compute feet contact time proportion.\n",
289289
"\n",
290290
" This function suppose that state descriptor is the feet contact, as it\n",
@@ -340,7 +340,7 @@
340340
"outputs": [],
341341
"source": [
342342
"def scoring_function(\n",
343-
" genotypes: jnp.ndarray, key: RNGKey\n",
343+
" genotypes: jax.Arraykey: RNGKey\n",
344344
") -> Tuple[Fitness, ExtraScores, RNGKey]:\n",
345345
" fitnesses, _, extra_scores = scoring_fn(genotypes, key)\n",
346346
" return fitnesses.reshape(-1, 1), extra_scores"

0 commit comments

Comments
 (0)