|
24 | 24 | "- how to visualise the optimization process" |
25 | 25 | ] |
26 | 26 | }, |
| 27 | + { |
| 28 | + "cell_type": "markdown", |
| 29 | + "metadata": {}, |
| 30 | + "source": [ |
| 31 | + "## Installation" |
| 32 | + ] |
| 33 | + }, |
| 34 | + { |
| 35 | + "cell_type": "markdown", |
| 36 | + "metadata": {}, |
| 37 | + "source": [ |
| 38 | + "You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:" |
| 39 | + ] |
| 40 | + }, |
27 | 41 | { |
28 | 42 | "cell_type": "code", |
29 | 43 | "execution_count": null, |
30 | 44 | "metadata": {}, |
31 | 45 | "outputs": [], |
32 | 46 | "source": [ |
33 | | - "from IPython.display import clear_output\n", |
34 | | - "\n", |
35 | | - "try:\n", |
36 | | - " import qdax\n", |
37 | | - "except:\n", |
38 | | - " print(\"QDax not found. Installing...\")\n", |
39 | | - " !pip install qdax[cuda12]\n", |
40 | | - " import qdax\n", |
41 | | - "\n", |
42 | | - "clear_output()" |
| 47 | + "%pip install -U \"jax[cuda]\"" |
| 48 | + ] |
| 49 | + }, |
| 50 | + { |
| 51 | + "cell_type": "markdown", |
| 52 | + "metadata": {}, |
| 53 | + "source": [ |
| 54 | + "Then, install QDax from PyPI:" |
| 55 | + ] |
| 56 | + }, |
| 57 | + { |
| 58 | + "cell_type": "code", |
| 59 | + "execution_count": null, |
| 60 | + "metadata": {}, |
| 61 | + "outputs": [], |
| 62 | + "source": [ |
| 63 | + "%pip install -U \"qdax[examples]\"" |
43 | 64 | ] |
44 | 65 | }, |
45 | 66 | { |
|
48 | 69 | "metadata": {}, |
49 | 70 | "outputs": [], |
50 | 71 | "source": [ |
51 | | - "!pip install ipympl | tail -n 1\n", |
52 | | - "# %matplotlib widget\n", |
53 | | - "# from google.colab import output\n", |
54 | | - "# output.enable_custom_widget_manager()\n", |
55 | | - "\n", |
56 | 72 | "import os\n", |
57 | 73 | "\n", |
58 | | - "from IPython.display import clear_output\n", |
59 | 74 | "import functools\n", |
60 | 75 | "from typing import Dict, Any\n", |
61 | 76 | "\n", |
|
78 | 93 | "from qdax.core.emitters.standard_emitters import MixingEmitter\n", |
79 | 94 | "\n", |
80 | 95 | "from qdax.custom_types import AuroraExtraInfoNormalization, Observation\n", |
81 | | - "from qdax.utils import train_seq2seq\n", |
82 | | - "\n", |
83 | | - "\n", |
84 | | - "if \"COLAB_TPU_ADDR\" in os.environ:\n", |
85 | | - " from jax.tools import colab_tpu\n", |
86 | | - " colab_tpu.setup_tpu()\n", |
87 | | - "\n", |
88 | | - "\n", |
89 | | - "clear_output()" |
| 96 | + "from qdax.utils import train_seq2seq" |
90 | 97 | ] |
91 | 98 | }, |
92 | 99 | { |
|
0 commit comments