Skip to content

Commit b97b75f

Browse files
committed
Allow for cuda support in pyproject
1 parent 7d382ca commit b97b75f

File tree

3 files changed

+217
-43
lines changed

3 files changed

+217
-43
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<h1 align="center">Doob’s Lagrangian: A Sample-Efficient Variational Approach to Transition Path Sampling</h1>
22
<p align="center">
33
<a href="https://github.com/plainerman/variational-doob"><img src="https://img.shields.io/badge/python-3670A0?style=for-the-badge&logo=python&logoColor=ffdd54" alt="Built with Python"/></a>
4-
<a href="https://github.com/plainerman/Variational-Doob/blob/main/notebooks/tps_gaussian.ipynb"><img src="https://img.shields.io/badge/jupyter-%23FA0F00.svg?style=for-the-badge&logo=jupyter&logoColor=white" alt="Jupyter"/></a>
4+
<a href="https://github.com/plainerman/variational-doob/blob/main/notebooks/tps_gaussian.ipynb"><img src="https://img.shields.io/badge/jupyter-%23FA0F00.svg?style=for-the-badge&logo=jupyter&logoColor=white" alt="Jupyter"/></a>
55
<a href="https://github.com/jax-ml/jax"><img src="https://img.shields.io/badge/library-JAX-5f0964?style=for-the-badge" alt="Jax"/></a>
66
</p>
77
<p align="center">

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ dependencies = [
1212
"scipy==1.12.0",
1313
"scikit-image>=0.24.0",
1414
"openpathsampling>=1.6.1",
15-
"jax==0.4.23",
15+
"jax==0.4.23; platform_system == 'Darwin'",
16+
"jax[cuda12]==0.4.23; platform_system != 'Darwin'",
1617
"jaxlib==0.4.23",
1718
"flax==0.8.3",
1819
"notebook>=7.2.2",

0 commit comments

Comments
 (0)