Skip to content

Commit f5bcec9

Browse files
authored
feat: add Brax v2 support and restructure brax tasks organization
* New Directory Structure: Created a new /tasks/brax directory with clear versioning: * /tasks/brax/v1/: Contains code compatible with Brax v1 API * /tasks/brax/v2/: Contains code compatible with Brax v2 API * Relocated environment-related code from /environments to /tasks/brax/v1/envs and /tasks/brax/v2/envs * Relocated wrapper code to /tasks/brax/v1/wrappers and /tasks/brax/v2/wrappers * Moved descriptor extractors to /tasks/brax/descriptor_extractors * Added base environment classes in both v1 and v2 paths * Fixed all import statements throughout the codebase to reference the new structure * Updated all relevant test files to use the new imports * Moved AuroraExtraInfo: Relocated AuroraExtraInfo and AuroraExtraInfoNormalization classes from descriptor_extractors to custom_types.py * Stop using task-specific types (e.g. brax.env.State JumanjiState) in the main custom types file
1 parent b5472d0 commit f5bcec9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+2055
-252
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ hydra_outputs/
2828

2929
# Usual python .gitignore
3030

31+
# Mac OS X files
32+
.DS_Store
33+
3134
# Byte-compiled / optimized / DLL files
3235
__pycache__/
3336
*.py[cod]
@@ -166,3 +169,6 @@ dmypy.json
166169

167170
# Cython debug symbols
168171
cython_debug/
172+
173+
# HTML files
174+
*.html

README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,16 @@ centroids = compute_euclidean_centroids(
111111
)
112112

113113
# Initializes repertoire and emitter state
114-
repertoire, emitter_state, key = map_elites.init(init_variables, centroids, key)
114+
key, subkey = jax.random.split(key)
115+
repertoire, emitter_state, metrics = map_elites.init(init_variables, centroids, subkey)
115116

116117
# Run MAP-Elites loop
117118
for i in range(num_iterations):
118-
(repertoire, emitter_state, metrics, key,) = map_elites.update(
119+
key, subkey = jax.random.split(key)
120+
(repertoire, emitter_state, metrics,) = map_elites.update(
119121
repertoire,
120122
emitter_state,
121-
key,
123+
subkey,
122124
)
123125

124126
# Get contents of repertoire

docs/api_documentation/environments.md

Lines changed: 0 additions & 3 deletions
This file was deleted.

examples/aurora.ipynb

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,21 +65,20 @@
6565
"\n",
6666
"from qdax.core.aurora import AURORA\n",
6767
"from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n",
68-
"from qdax import environments\n",
69-
"from qdax.tasks.brax_envs import (\n",
68+
"import qdax.tasks.brax.v1 as environments\n",
69+
"from qdax.tasks.brax.v1.env_creators import (\n",
7070
" create_default_brax_task_components,\n",
7171
" get_aurora_scoring_fn,\n",
7272
")\n",
73-
"from qdax.environments.descriptor_extractors import (\n",
74-
" AuroraExtraInfoNormalization,\n",
73+
"from qdax.tasks.brax.descriptor_extractors import (\n",
7574
" get_aurora_encoding,\n",
7675
")\n",
7776
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
7877
"from qdax.core.neuroevolution.networks.networks import MLP\n",
7978
"from qdax.core.emitters.mutation_operators import isoline_variation\n",
8079
"from qdax.core.emitters.standard_emitters import MixingEmitter\n",
8180
"\n",
82-
"from qdax.custom_types import Observation\n",
81+
"from qdax.custom_types import AuroraExtraInfoNormalization, Observation\n",
8382
"from qdax.utils import train_seq2seq\n",
8483
"\n",
8584
"\n",

examples/dads.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
"import jax\n",
6060
"import jax.numpy as jnp\n",
6161
"\n",
62-
"from qdax import environments\n",
62+
"import qdax.tasks.brax.v1 as environments\n",
6363
"from qdax.baselines.dads import DADS, DadsConfig, DadsTrainingState\n",
6464
"from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer\n",
6565
"from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer\n",

examples/dcrlme.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,17 @@
6464
"import jax\n",
6565
"import jax.numpy as jnp\n",
6666
"\n",
67-
"from qdax import environments\n",
67+
"import qdax.tasks.brax.v1 as environments\n",
6868
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n",
6969
"from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter\n",
7070
"from qdax.core.emitters.mutation_operators import isoline_variation\n",
7171
"from qdax.core.map_elites import MAPElites\n",
7272
"from qdax.core.neuroevolution.buffers.buffer import DCRLTransition\n",
7373
"from qdax.core.neuroevolution.networks.networks import MLP, MLPDC\n",
7474
"from qdax.custom_types import EnvState, Params, RNGKey\n",
75-
"from qdax.environments import descriptor_extractor\n",
76-
"from qdax.environments.wrappers import OffsetRewardWrapper, ClipRewardWrapper\n",
77-
"from qdax.tasks.brax_envs import scoring_function_brax_envs\n",
75+
"from qdax.tasks.brax.v1 import descriptor_extractor\n",
76+
"from qdax.tasks.brax.v1.wrappers.reward_wrappers import OffsetRewardWrapper, ClipRewardWrapper\n",
77+
"from qdax.tasks.brax.v1.env_creators import scoring_function_brax_envs\n",
7878
"from qdax.utils.plotting import plot_map_elites_results\n",
7979
"\n",
8080
"from qdax.utils.metrics import CSVLogger, default_qd_metrics\n",
@@ -101,7 +101,7 @@
101101
"min_descriptor = -30.0 #@param {type:\"number\"}\n",
102102
"max_descriptor = 30.0 #@param {type:\"number\"}\n",
103103
"\n",
104-
"num_iterations = 1000 #@param {type:\"integer\"}\n",
104+
"num_iterations = 200 #@param {type:\"integer\"}\n",
105105
"batch_size = 256 #@param {type:\"integer\"}\n",
106106
"\n",
107107
"# Archive\n",

examples/diayn.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
"import jax\n",
6060
"import jax.numpy as jnp\n",
6161
"\n",
62-
"from qdax import environments\n",
62+
"import qdax.tasks.brax.v1 as environments\n",
6363
"from qdax.baselines.diayn import DIAYN, DiaynConfig, DiaynTrainingState\n",
6464
"from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer\n",
6565
"from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer\n",

examples/distributed_mapelites.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@
7070
"\n",
7171
"from qdax.core.distributed_map_elites import DistributedMAPElites\n",
7272
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n",
73-
"from qdax import environments\n",
74-
"from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function\n",
73+
"import qdax.tasks.brax.v1 as environments\n",
74+
"from qdax.tasks.brax.v1.env_creators import scoring_function_brax_envs as scoring_function\n",
7575
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
7676
"from qdax.core.neuroevolution.networks.networks import MLP\n",
7777
"from qdax.core.emitters.mutation_operators import isoline_variation\n",

examples/mapelites.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@
6565
"\n",
6666
"from qdax.core.map_elites import MAPElites\n",
6767
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire\n",
68-
"from qdax import environments\n",
69-
"from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function\n",
68+
"import qdax.tasks.brax.v1 as environments\n",
69+
"from qdax.tasks.brax.v1.env_creators import scoring_function_brax_envs as scoring_function\n",
7070
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
7171
"from qdax.core.neuroevolution.networks.networks import MLP\n",
7272
"from qdax.core.emitters.mutation_operators import isoline_variation\n",
@@ -99,7 +99,7 @@
9999
"batch_size = 100 #@param {type:\"number\"}\n",
100100
"env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n",
101101
"episode_length = 100 #@param {type:\"integer\"}\n",
102-
"num_iterations = 1000 #@param {type:\"integer\"}\n",
102+
"num_iterations = 500 #@param {type:\"integer\"}\n",
103103
"seed = 42 #@param {type:\"integer\"}\n",
104104
"policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n",
105105
"iso_sigma = 0.005 #@param {type:\"number\"}\n",

examples/mapelites_asktell.ipynb

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@
6464
"import jax.numpy as jnp\n",
6565
"\n",
6666
"from qdax.core.map_elites import MAPElites\n",
67-
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire\n",
68-
"from qdax import environments\n",
69-
"from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function\n",
67+
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n",
68+
"import qdax.tasks.brax.v1 as environments\n",
69+
"from qdax.tasks.brax.v1.env_creators import scoring_function_brax_envs as scoring_function\n",
7070
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
7171
"from qdax.core.neuroevolution.networks.networks import MLP\n",
7272
"from qdax.core.emitters.mutation_operators import isoline_variation\n",
@@ -75,11 +75,6 @@
7575
"\n",
7676
"from qdax.utils.metrics import CSVLogger, default_qd_metrics\n",
7777
"\n",
78-
"from jax.flatten_util import ravel_pytree\n",
79-
"\n",
80-
"from IPython.display import HTML\n",
81-
"from brax.v1.io import html\n",
82-
"\n",
8378
"\n",
8479
"if \"COLAB_TPU_ADDR\" in os.environ:\n",
8580
" from jax.tools import colab_tpu\n",

0 commit comments

Comments
 (0)