Skip to content

Commit 9e5fe03

Browse files
committed
update installation setup
1 parent 4417069 commit 9e5fe03

File tree

3 files changed

+53
-42
lines changed

3 files changed

+53
-42
lines changed

brainpy/__init__.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,27 @@
1010
raise ModuleNotFoundError(
1111
'''
1212
13-
Please install jaxlib. See
14-
15-
https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax
13+
BrainPy needs jaxlib, please install jaxlib.
1614
17-
for installation instructions.
15+
1. If you are using Windows system, install jaxlib through
16+
17+
>>> pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html
18+
19+
2. If you are using macOS platform, install jaxlib through
20+
21+
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
22+
23+
3. If you are using Linux platform, install jaxlib through
24+
25+
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
26+
27+
4. If you are using Linux + CUDA platform, install jaxlib through
28+
29+
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
30+
31+
Note that the versions of "jax" and "jaxlib" should be consistent, like "jax=0.3.14", "jaxlib=0.3.14".
32+
33+
More detail installation instruction, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax
1834
1935
''') from None
2036

docs/quickstart/installation.rst

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,18 @@ Linux & MacOS
8989
^^^^^^^^^^^^^
9090

9191
Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or
92-
later) platforms. The provided binary releases of JAX for Linux and macOS
92+
later) platforms. The provided binary releases of `jax` and `jaxlib` for Linux and macOS
9393
systems are available at
9494

9595
- for CPU: https://storage.googleapis.com/jax-releases/jax_releases.html
9696
- for GPU: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
9797

9898

99-
To install a CPU-only version of JAX, you can run
99+
If you want to install a CPU-only version of `jax` and `jaxlib`, you can run
100100

101101
.. code-block:: bash
102102
103-
pip install --upgrade "jax[cpu]"
103+
pip install --upgrade "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
104104
105105
If you want to install JAX with both CPU and NVidia GPU support, you must first install
106106
`CUDA`_ and `CuDNN`_, if they have not already been installed. Next, run
@@ -109,42 +109,57 @@ If you want to install JAX with both CPU and NVidia GPU support, you must first
109109
110110
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
111111
112-
Alternatively, you can download the preferred release ".whl" file for jaxlib, and install it via ``pip``:
112+
113+
Alternatively, you can download the preferred release ".whl" file for jaxlib
114+
from the above release links, and install it via ``pip``:
113115

114116
.. code-block:: bash
115117
116118
pip install xxx-0.3.14-xxx.whl
117119
118120
pip install jax==0.3.14
119121
120-
Note that the versions of `jaxlib` and `jax` should be consistent.
122+
.. note::
123+
124+
Note that the versions of `jaxlib` and `jax` should be consistent.
125+
126+
For example, if you are using `jax==0.3.14`, you would better install `jax==0.3.14`.
127+
121128

122129

123130
Windows
124131
^^^^^^^
125132

126-
For **Windows** users, JAX can be installed by the following methods:
133+
For **Windows** users, `jax` and `jaxlib` can be installed from the community supports.
134+
Specifically, you can install `jax` and `jaxlib` through:
135+
136+
.. code-block:: bash
137+
138+
pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html
139+
140+
If you are using GPU, you can install GPU-versioned wheels through:
141+
142+
.. code-block:: bash
127143
128-
- **Method 1**: There are several communities support JAX for Windows, please refer
129-
to the github link for more details: https://github.com/cloudhan/jax-windows-builder .
130-
Simply speaking, the provided binary releases of JAX for Windows
131-
are available at https://whls.blob.core.windows.net/unstable/index.html .
144+
pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html
132145
133-
You can download the preferred release ".whl" file, and install it via ``pip``:
146+
Alternatively, you can manually install you favourite version of `jax` and `jaxlib` by
147+
downloading binary releases of JAX for Windows from https://whls.blob.core.windows.net/unstable/index.html .
148+
Then install it via ``pip``:
134149

135150
.. code-block:: bash
136151
137152
pip install xxx-0.3.14-xxx.whl
138153
139154
pip install jax==0.3.14
140155
141-
- **Method 2**: For Windows 10+ system, you can use `Windows Subsystem for Linux (WSL)`_.
142-
The installation guide can be found in `WSL Installation Guide for Windows 10`_.
143-
Then, you can install JAX in WSL just like the installation step in Linux/MacOs.
144-
145-
146-
- **Method 3**: You can also `build JAX from source`_.
156+
WSL
157+
^^^
147158

159+
Moreover, for Windows 10+ system, we recommend using `Windows Subsystem for Linux (WSL)`_.
160+
The installation guide can be found in
161+
`WSL Installation Guide for Windows 10/11 <https://docs.microsoft.com/en-us/windows/wsl/install-win10>`_.
162+
Then, you can install JAX in WSL just like the installation step in Linux/MacOs.
148163

149164

150165
Dependency 3: brainpylib
@@ -194,7 +209,6 @@ packages:
194209
.. _Matplotlib: https://matplotlib.org/
195210
.. _JAX: https://github.com/google/jax
196211
.. _Windows Subsystem for Linux (WSL): https://docs.microsoft.com/en-us/windows/wsl/about
197-
.. _WSL Installation Guide for Windows 10: https://docs.microsoft.com/en-us/windows/wsl/install-win10
198212
.. _build JAX from source: https://jax.readthedocs.io/en/latest/developer.html
199213
.. _SymPy: https://github.com/sympy/sympy
200214
.. _Numba: https://numba.pydata.org/

setup.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,6 @@
3737
with io.open(os.path.join(here, 'README.md'), 'r', encoding='utf-8') as f:
3838
README = f.read()
3939

40-
# require users to install jaxlib before installing brainpy on Windows platform
41-
requirements = ['numpy>=1.15', 'jax>=0.3.0', 'tqdm']
42-
if sys.platform.startswith('win32') or sys.platform.startswith('cygwin'):
43-
try:
44-
import jaxlib
45-
except ModuleNotFoundError:
46-
raise ModuleNotFoundError('''
47-
48-
----------------------------------------------------------------------
49-
We detect that your are using Windows platform.
50-
Please manually install "jaxlib" before installing "brainpy".
51-
See https://whls.blob.core.windows.net/unstable/index.html
52-
for jaxlib's Windows wheels.
53-
----------------------------------------------------------------------
54-
55-
''') from None
56-
else:
57-
requirements.append('jaxlib>=0.3.0')
58-
5940
# installation packages
6041
packages = find_packages()
6142
if 'docs' in packages:
@@ -74,7 +55,7 @@
7455
author_email='[email protected]',
7556
packages=packages,
7657
python_requires='>=3.7',
77-
install_requires=requirements,
58+
install_requires=['numpy>=1.15', 'jax>=0.3.0', 'tqdm'],
7859
url='https://github.com/PKU-NIP-Lab/BrainPy',
7960
project_urls={
8061
"Bug Tracker": "https://github.com/PKU-NIP-Lab/BrainPy/issues",

0 commit comments

Comments
 (0)