Skip to content

Commit ee3b103

Browse files
authored
APEBench and its dependencies + Updates to equinox module for supporting cuda as well as latest versions (spack#1662)
1 parent 5ece166 commit ee3b103

File tree

5 files changed

+216
-11
lines changed

5 files changed

+216
-11
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright Spack Project Developers. See COPYRIGHT file for details.
2+
#
3+
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
4+
5+
6+
from spack_repo.builtin.build_systems.cuda import CudaPackage
7+
from spack_repo.builtin.build_systems.python import PythonPackage
8+
9+
from spack.package import *
10+
11+
12+
class PyApebench(PythonPackage, CudaPackage):
13+
"""APEBench is a JAX-based tool to evaluate autoregressive neural emulators for PDEs
14+
on periodic domains in 1d, 2d, and 3d. It comes with an efficient reference simulator
15+
based on spectral methods that is used for procedural data generation (no need to download
16+
large datasets with APEBench). Since this simulator can also be embedded into emulator training
17+
(e.g., for a "solver-in-the-loop" correction setting), this is the first benchmark suite to
18+
support differentiable physics."""
19+
20+
homepage = "https://tum-pbs.github.io/apebench-paper/"
21+
pypi = "apebench/apebench-0.1.1.tar.gz"
22+
23+
maintainers("abhishek1297")
24+
license("MIT", checked_by="abhishek1297")
25+
26+
version("0.1.1", sha256="c5ddd47799f0799b2c2e72c27d3d81993f6fa218a04b1df93d4c1850e4893bf9")
27+
28+
depends_on("c", type="build")
29+
depends_on("cxx", type="build")
30+
depends_on("py-setuptools", type="build")
31+
depends_on("[email protected]:3.12", type=("build", "run"))
32+
33+
with default_args(type="run"):
34+
for arch in CudaPackage.cuda_arch_values:
35+
cuda_specs = f"+cuda cuda_arch={arch}"
36+
with when(cuda_specs):
37+
depends_on(f"py-jaxlib {cuda_specs}")
38+
depends_on(f"[email protected]: {cuda_specs}")
39+
depends_on(f"[email protected] {cuda_specs}")
40+
depends_on(f"[email protected] {cuda_specs}")
41+
depends_on(f"[email protected] {cuda_specs}")
42+
43+
depends_on("[email protected]:")
44+
depends_on("[email protected]:")
45+
depends_on("[email protected]:")
46+
depends_on("[email protected]:")
47+
depends_on("[email protected]:")
48+
depends_on("[email protected]:")
49+
depends_on("[email protected]:")
50+
depends_on("[email protected]:")
51+
52+
with when("~cuda"):
53+
depends_on("[email protected]:")
54+
depends_on("[email protected]")
55+
depends_on("[email protected]")
56+
depends_on("[email protected]")
57+
58+
def setup_run_environment(self, env):
59+
if "+cuda" in self.spec:
60+
cuda_home = self.spec["cuda"].prefix
61+
# This is an irrelevant lib path and it is purely used by NVIDIA profilers.
62+
# But, since JAX throws RuntimeError on it, we set this path.
63+
env.prepend_path("LD_LIBRARY_PATH", f"{cuda_home}/extras/CUPTI/lib64")

repos/spack_repo/builtin/packages/py_equinox/package.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,60 @@
22
#
33
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
44

5+
from spack_repo.builtin.build_systems.cuda import CudaPackage
56
from spack_repo.builtin.build_systems.python import PythonPackage
67

78
from spack.package import *
89

910

10-
class PyEquinox(PythonPackage):
11-
"""Equinox is a comprehensive JAX library that provides a wide
12-
range of tools and features not found in core JAX, including neural networks
13-
with PyTorch-like syntax, filtered APIs for transformations, PyTree manipulation
14-
routines, and advanced features like runtime errors."""
11+
class PyEquinox(PythonPackage, CudaPackage):
12+
"""Equinox is your one-stop [JAX](https://github.com/google/jax) library,
13+
for everything you need that isn't already in core JAX:
14+
- neural networks (or more generally any model), with easy-to-use PyTorch-like syntax;
15+
- filtered APIs for transformations;
16+
- useful PyTree manipulation routines;
17+
- advanced features like runtime errors;
1518
16-
homepage = "https://docs.kidger.site/equinox/"
19+
and best of all, Equinox isn't a framework: everything you write in Equinox is compatible
20+
with anything else in JAX or the ecosystem."""
1721

18-
pypi = "equinox/equinox-0.11.10.tar.gz"
22+
homepage = "https://docs.kidger.site/equinox/"
23+
pypi = "equinox/equinox-0.13.1.tar.gz"
1924

20-
license("Apache-2.0", checked_by="viperML")
25+
maintainers("abhishek1297")
26+
license("Apache-2.0", checked_by="abhishek1297")
2127

28+
version("0.13.1", sha256="e90f11cfe66b2f73f5c172260a17c48851794a0f243dd2cbe4ea70f4c90cbd07")
29+
version("0.13.0", sha256="d59615be722373e9d66e0ba78462964e6357fb76a8b1b98c2c6027961b778a69")
30+
version("0.12.2", sha256="648e4206bbc53b228922e8f18cd3cffe543ddda1172c0002f8954e484bab0023")
31+
version("0.12.1", sha256="7ed4b84553cb59d4930185f87ac2c1121aab2b38999be9499c021e7583a7ed0d")
32+
version("0.12.0", sha256="6a99877376cfc168cfe44220a734740926bf23eb9c0cd0d7fdc49adfec4d78ca")
2233
version("0.11.12", sha256="bee22aabaf7ee0cde6f2ae58cf3b981dea73d47e297361a0203e299208ef1739")
34+
version("0.11.11", sha256="648072c1384adc3528930a3bf089246fd77aa873310a19f1f21c08e7681f95a7")
2335
version("0.11.10", sha256="f3e7d5545b71e427859a28050526d09adb6b20285c47476a606328a0b96c9509")
36+
version("0.11.9", sha256="e0f0fa5ea597949492d201ab4d08b05c2d5b4020c65a1778aedf6ad76c2c4fe7")
37+
version("0.11.8", sha256="d1e91a05e41bb9538db72a8e15d26daf958348c26714533434c88c5ec0c0b0ef")
38+
version("0.11.7", sha256="96e0216a9d822ec4b1465b0cbfbab14a36fb7e7d62c55f521287db3aaaa251be")
39+
version("0.11.6", sha256="e237c25e446960ed479f086df240d4dd779bb0917bafc76811d341ccac76b712")
40+
version("0.11.5", sha256="5e0ca252eeb20cc5eece225d3d35137c7e57f998a1c6422a1972db9e5c68b7f6")
41+
version("0.11.4", sha256="0033d9731083f402a855b12a0777a80aa8507651f7aa86d9f0f9503bcddfd320")
42+
version("0.11.3", sha256="a1273cc28c60d3131ac596f8a0f5c7dd384729e6cddae86e7be05f026880e8e0")
2443

2544
depends_on("py-hatchling", type="build")
45+
depends_on("[email protected]:3.12", type=("build", "run"))
46+
47+
with default_args(type="run"):
48+
for arch in CudaPackage.cuda_arch_values:
49+
cuda_specs = f"+cuda cuda_arch={arch}"
50+
with when(cuda_specs):
51+
depends_on(f"[email protected]:0.4.26 {cuda_specs}", when="@:0.11.10")
52+
depends_on(f"[email protected]:0.5 {cuda_specs}", when="@0.11.11:0.11.12")
53+
depends_on(f"[email protected]:0.6 {cuda_specs}", when="@0.12:")
2654

27-
with default_args(type=("build", "run")):
28-
depends_on("[email protected]:")
2955
depends_on("[email protected]:0.4.26", when="@:0.11.10")
30-
depends_on("[email protected]:", when="@0.11.11:")
56+
depends_on("[email protected]:0.5", when="@0.11.11:0.11.12")
57+
depends_on("[email protected]:0.6", when="@0.12:")
58+
3159
depends_on("[email protected]:")
3260
depends_on("[email protected]:")
3361
depends_on("[email protected]:")
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright Spack Project Developers. See COPYRIGHT file for details.
2+
#
3+
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
4+
5+
6+
from spack_repo.builtin.build_systems.cuda import CudaPackage
7+
from spack_repo.builtin.build_systems.python import PythonPackage
8+
9+
from spack.package import *
10+
11+
12+
class PyExponax(PythonPackage, CudaPackage):
13+
"""Efficient Differentiable n-d PDE solvers built on top of JAX & Equinox."""
14+
15+
homepage = "https://fkoehler.site/exponax/"
16+
pypi = "exponax/exponax-0.1.0.tar.gz"
17+
18+
maintainers("abhishek1297")
19+
license("MIT", checked_by="abhishek1297")
20+
21+
version("0.1.0", sha256="25acdb5c1b76f5706316750a3133f427f0faec441a1ffe3b90697d5f32abb5e7")
22+
version("0.0.1", sha256="e2a201752d38dbfd233d52c2f59ed0dc344ccbb3e796b26c2713c6a2357d7366")
23+
24+
depends_on("py-setuptools", type="build")
25+
depends_on("[email protected]:3.12", type=("build", "run"))
26+
27+
with default_args(type="run"):
28+
for arch in CudaPackage.cuda_arch_values:
29+
cuda_specs = f"+cuda cuda_arch={arch}"
30+
depends_on(f"[email protected]: {cuda_specs}", when=f"{cuda_specs}")
31+
32+
depends_on("[email protected]:")
33+
depends_on("[email protected]:")
34+
depends_on("[email protected]:")
35+
depends_on("[email protected]:")
36+
depends_on("[email protected]:")
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright Spack Project Developers. See COPYRIGHT file for details.
2+
#
3+
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
4+
5+
6+
from spack_repo.builtin.build_systems.cuda import CudaPackage
7+
from spack_repo.builtin.build_systems.python import PythonPackage
8+
9+
from spack.package import *
10+
11+
12+
class PyPdequinox(PythonPackage, CudaPackage):
13+
"""A collection of neural architectures for emulating Partial Differential Equations (PDEs)
14+
in JAX agnostic to the spatial dimension (1D, 2D, 3D) and boundary conditions
15+
(Dirichlet, Neumann, Periodic). This package is built on top of Equinox."""
16+
17+
homepage = "https://fkoehler.site/pdequinox/"
18+
pypi = "pdequinox/pdequinox-0.1.2.tar.gz"
19+
20+
maintainers("abhishek1297")
21+
license("MIT", checked_by="abhishek1297")
22+
23+
version("0.1.2", sha256="7ee9dcbf277cbb94cda508034c0955600a03bc4c664bede5eb61b4a4b99b54c5")
24+
version("0.1.0", sha256="07f7516fe26823e6c3b71f1ed5a170e97cc34ff1d1349435d4b7469adc540d3a")
25+
26+
depends_on("py-setuptools", type="build")
27+
depends_on("[email protected]:3.12", type=("build", "run"))
28+
29+
with default_args(type="run"):
30+
for arch in CudaPackage.cuda_arch_values:
31+
cuda_specs = f"+cuda cuda_arch={arch}"
32+
depends_on(f"[email protected]: {cuda_specs}", when=f"{cuda_specs}")
33+
34+
depends_on("[email protected]:")
35+
depends_on("[email protected]:")
36+
depends_on("[email protected]:")
37+
depends_on("[email protected]:")
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright Spack Project Developers. See COPYRIGHT file for details.
2+
#
3+
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
4+
5+
6+
from spack_repo.builtin.build_systems.cuda import CudaPackage
7+
from spack_repo.builtin.build_systems.python import PythonPackage
8+
9+
from spack.package import *
10+
11+
12+
class PyTrainax(PythonPackage, CudaPackage):
13+
"""Convenience abstractions using optax to train neural networks to autoregressively
14+
emulate time-dependent problems taking care of trajectory subsampling and offering a wide
15+
range of training methodologies (regarding unrolling length and including
16+
differentiable physics).
17+
"""
18+
19+
homepage = "https://fkoehler.site/trainax/"
20+
pypi = "trainax/trainax-0.0.2.tar.gz"
21+
22+
maintainers("abhishek1297")
23+
license("MIT", checked_by="abhishek1297")
24+
25+
version("0.0.2", sha256="3c7eeeb94e351db7ff0b036b1c1fb6f78ddc25ab72d6c1afe69547cbefa70ca8")
26+
version("0.0.1", sha256="19552dfca2d6f9d7e69963e978628adb19dc2ba9cb9563b510c19e136116c23a")
27+
28+
depends_on("py-setuptools", type="build")
29+
depends_on("[email protected]:3.12", type=("build", "run"))
30+
31+
with default_args(type="run"):
32+
for arch in CudaPackage.cuda_arch_values:
33+
cuda_specs = f"+cuda cuda_arch={arch}"
34+
depends_on(f"[email protected]: {cuda_specs}", when=f"{cuda_specs}")
35+
36+
depends_on("[email protected]:")
37+
depends_on("[email protected]:")
38+
depends_on("[email protected]:")
39+
depends_on("[email protected]:")
40+
depends_on("[email protected]:")
41+
depends_on("[email protected]:")

0 commit comments

Comments
 (0)