Skip to content

Commit e86c7db

Browse files
committed
Fix qdax installation in notebooks
1 parent 29cffb1 commit e86c7db

26 files changed

+821
-459
lines changed

examples/aurora.ipynb

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,43 @@
2424
"- how to visualise the optimization process"
2525
]
2626
},
27+
{
28+
"cell_type": "markdown",
29+
"metadata": {},
30+
"source": [
31+
"## Installation"
32+
]
33+
},
34+
{
35+
"cell_type": "markdown",
36+
"metadata": {},
37+
"source": [
38+
"You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:"
39+
]
40+
},
2741
{
2842
"cell_type": "code",
2943
"execution_count": null,
3044
"metadata": {},
3145
"outputs": [],
3246
"source": [
33-
"from IPython.display import clear_output\n",
34-
"\n",
35-
"try:\n",
36-
" import qdax\n",
37-
"except:\n",
38-
" print(\"QDax not found. Installing...\")\n",
39-
" !pip install qdax[cuda12]\n",
40-
" import qdax\n",
41-
"\n",
42-
"clear_output()"
47+
"%pip install -U \"jax[cuda]\""
48+
]
49+
},
50+
{
51+
"cell_type": "markdown",
52+
"metadata": {},
53+
"source": [
54+
"Then, install QDax from PyPI:"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"metadata": {},
61+
"outputs": [],
62+
"source": [
63+
"%pip install -U \"qdax[examples]\""
4364
]
4465
},
4566
{
@@ -48,14 +69,8 @@
4869
"metadata": {},
4970
"outputs": [],
5071
"source": [
51-
"!pip install ipympl | tail -n 1\n",
52-
"# %matplotlib widget\n",
53-
"# from google.colab import output\n",
54-
"# output.enable_custom_widget_manager()\n",
55-
"\n",
5672
"import os\n",
5773
"\n",
58-
"from IPython.display import clear_output\n",
5974
"import functools\n",
6075
"from typing import Dict, Any\n",
6176
"\n",
@@ -78,15 +93,7 @@
7893
"from qdax.core.emitters.standard_emitters import MixingEmitter\n",
7994
"\n",
8095
"from qdax.custom_types import AuroraExtraInfoNormalization, Observation\n",
81-
"from qdax.utils import train_seq2seq\n",
82-
"\n",
83-
"\n",
84-
"if \"COLAB_TPU_ADDR\" in os.environ:\n",
85-
" from jax.tools import colab_tpu\n",
86-
" colab_tpu.setup_tpu()\n",
87-
"\n",
88-
"\n",
89-
"clear_output()"
96+
"from qdax.utils import train_seq2seq"
9097
]
9198
},
9299
{

examples/cmaes.ipynb

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,48 @@
2323
"- how to visualise the optimization process"
2424
]
2525
},
26+
{
27+
"cell_type": "markdown",
28+
"id": "931313f8",
29+
"metadata": {},
30+
"source": [
31+
"## Installation"
32+
]
33+
},
34+
{
35+
"cell_type": "markdown",
36+
"id": "d43f4db2",
37+
"metadata": {},
38+
"source": [
39+
"You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:"
40+
]
41+
},
2642
{
2743
"cell_type": "code",
2844
"execution_count": null,
29-
"id": "2",
45+
"id": "ad841f6f",
3046
"metadata": {},
3147
"outputs": [],
3248
"source": [
33-
"from IPython.display import clear_output\n",
34-
"\n",
35-
"try:\n",
36-
" import qdax\n",
37-
"except:\n",
38-
" print(\"QDax not found. Installing...\")\n",
39-
" !pip install qdax[cuda12]\n",
40-
" import qdax\n",
41-
"\n",
42-
"clear_output()"
49+
"%pip install -U \"jax[cuda]\""
50+
]
51+
},
52+
{
53+
"cell_type": "markdown",
54+
"id": "366885ad",
55+
"metadata": {},
56+
"source": [
57+
"Then, install QDax from PyPI:"
58+
]
59+
},
60+
{
61+
"cell_type": "code",
62+
"execution_count": null,
63+
"id": "07a2fbdc",
64+
"metadata": {},
65+
"outputs": [],
66+
"source": [
67+
"%pip install -U \"qdax[examples]\""
4368
]
4469
},
4570
{

examples/cmame.ipynb

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,43 @@
2323
"- how to visualise the optimization process"
2424
]
2525
},
26+
{
27+
"cell_type": "markdown",
28+
"metadata": {},
29+
"source": [
30+
"## Installation"
31+
]
32+
},
33+
{
34+
"cell_type": "markdown",
35+
"metadata": {},
36+
"source": [
37+
"You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:"
38+
]
39+
},
2640
{
2741
"cell_type": "code",
2842
"execution_count": null,
2943
"metadata": {},
3044
"outputs": [],
3145
"source": [
32-
"from IPython.display import clear_output\n",
33-
"\n",
34-
"try:\n",
35-
" import qdax\n",
36-
"except:\n",
37-
" print(\"QDax not found. Installing...\")\n",
38-
" !pip install qdax[cuda12]\n",
39-
" import qdax\n",
40-
"\n",
41-
"clear_output()"
46+
"%pip install -U \"jax[cuda]\""
47+
]
48+
},
49+
{
50+
"cell_type": "markdown",
51+
"metadata": {},
52+
"source": [
53+
"Then, install QDax from PyPI:"
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"metadata": {},
60+
"outputs": [],
61+
"source": [
62+
"%pip install -U \"qdax[examples]\""
4263
]
4364
},
4465
{

examples/cmamega.ipynb

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,43 @@
2323
"- how to visualise the optimization process"
2424
]
2525
},
26+
{
27+
"cell_type": "markdown",
28+
"metadata": {},
29+
"source": [
30+
"## Installation"
31+
]
32+
},
33+
{
34+
"cell_type": "markdown",
35+
"metadata": {},
36+
"source": [
37+
"You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:"
38+
]
39+
},
2640
{
2741
"cell_type": "code",
2842
"execution_count": null,
2943
"metadata": {},
3044
"outputs": [],
3145
"source": [
32-
"from IPython.display import clear_output\n",
33-
"\n",
34-
"try:\n",
35-
" import qdax\n",
36-
"except:\n",
37-
" print(\"QDax not found. Installing...\")\n",
38-
" !pip install qdax[cuda12]\n",
39-
" import qdax\n",
40-
"\n",
41-
"clear_output()"
46+
"%pip install -U \"jax[cuda]\""
47+
]
48+
},
49+
{
50+
"cell_type": "markdown",
51+
"metadata": {},
52+
"source": [
53+
"Then, install QDax from PyPI:"
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"metadata": {},
60+
"outputs": [],
61+
"source": [
62+
"%pip install -U \"qdax[examples]\""
4263
]
4364
},
4465
{

examples/dads.ipynb

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,43 @@
2222
"- how to visualise the final trajectories learned"
2323
]
2424
},
25+
{
26+
"cell_type": "markdown",
27+
"metadata": {},
28+
"source": [
29+
"## Installation"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:"
37+
]
38+
},
2539
{
2640
"cell_type": "code",
2741
"execution_count": null,
2842
"metadata": {},
2943
"outputs": [],
3044
"source": [
31-
"from IPython.display import clear_output\n",
32-
"\n",
33-
"try:\n",
34-
" import qdax\n",
35-
"except:\n",
36-
" print(\"QDax not found. Installing...\")\n",
37-
" !pip install qdax[cuda12]\n",
38-
" import qdax\n",
39-
"\n",
40-
"clear_output()"
45+
"%pip install -U \"jax[cuda]\""
46+
]
47+
},
48+
{
49+
"cell_type": "markdown",
50+
"metadata": {},
51+
"source": [
52+
"Then, install QDax from PyPI:"
53+
]
54+
},
55+
{
56+
"cell_type": "code",
57+
"execution_count": null,
58+
"metadata": {},
59+
"outputs": [],
60+
"source": [
61+
"%pip install -U \"qdax[examples]\""
4162
]
4263
},
4364
{
@@ -46,11 +67,6 @@
4667
"metadata": {},
4768
"outputs": [],
4869
"source": [
49-
"!pip install ipympl | tail -n 1\n",
50-
"# %matplotlib widget\n",
51-
"# from google.colab import output\n",
52-
"# output.enable_custom_widget_manager()\n",
53-
"\n",
5470
"import os\n",
5571
"\n",
5672
"from IPython.display import clear_output\n",
@@ -67,16 +83,7 @@
6783
"from qdax.utils.plotting import plot_skills_trajectory\n",
6884
"\n",
6985
"from IPython.display import HTML\n",
70-
"from brax.v1.io import html\n",
71-
"\n",
72-
"\n",
73-
"\n",
74-
"if \"COLAB_TPU_ADDR\" in os.environ:\n",
75-
" from jax.tools import colab_tpu\n",
76-
" colab_tpu.setup_tpu()\n",
77-
"\n",
78-
"\n",
79-
"clear_output()"
86+
"from brax.v1.io import html"
8087
]
8188
},
8289
{

0 commit comments

Comments
 (0)