Skip to content

Commit 63f7a8b

Browse files
committed
ENH: enable RunsOn with custom ami
1 parent 13abb75 commit 63f7a8b

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

.github/workflows/ci.yml

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,28 @@ name: Build Project [using jupyter-book]
22
on: [pull_request]
33
jobs:
44
preview:
5-
runs-on: quantecon-gpu
6-
container:
7-
image: ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
8-
options: --gpus all
5+
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2204_ami/disk=large"
96
steps:
107
- uses: actions/checkout@v4
118
with:
129
ref: ${{ github.event.pull_request.head.sha }}
13-
# Check nvidia drivers
14-
- name: nvidia Drivers
10+
- name: Setup Anaconda
11+
uses: conda-incubator/setup-miniconda@v3
12+
with:
13+
auto-update-conda: true
14+
auto-activate-base: true
15+
miniconda-version: 'latest'
16+
python-version: "3.12"
17+
environment-file: environment.yml
18+
activate-environment: quantecon
19+
- name: Install JAX, Numpyro, PyTorch
20+
shell: bash -l {0}
21+
run: |
22+
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
23+
pip install --upgrade "jax[cuda12-local]"
24+
pip install numpyro
25+
python scripts/test-jax-install.py
26+
- name: Check nvidia Drivers
1527
shell: bash -l {0}
1628
run: nvidia-smi
1729
- name: Display Conda Environment Versions

scripts/test-jax-install.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
devices = jax.devices()
5+
print(f"The available devices are: {devices}")
6+
7+
@jax.jit
8+
def matrix_multiply(a, b):
9+
return jnp.dot(a, b)
10+
11+
# Example usage:
12+
key = jax.random.PRNGKey(0)
13+
x = jax.random.normal(key, (1000, 1000))
14+
y = jax.random.normal(key, (1000, 1000))
15+
z = matrix_multiply(x, y)
16+
17+
# Now the function is JIT compiled and will likely run on GPU (if available)
18+
print(z)
19+
20+
devices = jax.devices()
21+
print(f"The available devices are: {devices}")

0 commit comments

Comments
 (0)