File tree Expand file tree Collapse file tree 2 files changed +39
-6
lines changed
Expand file tree Collapse file tree 2 files changed +39
-6
lines changed Original file line number Diff line number Diff line change @@ -2,16 +2,28 @@ name: Build Project [using jupyter-book]
22on : [pull_request]
33jobs :
44 preview :
5- runs-on : quantecon-gpu
6- container :
7- image : ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
8- options : --gpus all
5+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2204_ami/disk=large"
96 steps :
107 - uses : actions/checkout@v4
118 with :
129 ref : ${{ github.event.pull_request.head.sha }}
13- # Check nvidia drivers
14- - name : nvidia Drivers
10+ - name : Setup Anaconda
11+ uses : conda-incubator/setup-miniconda@v3
12+ with :
13+ auto-update-conda : true
14+ auto-activate-base : true
15+ miniconda-version : ' latest'
16+ python-version : " 3.12"
17+ environment-file : environment.yml
18+ activate-environment : quantecon
19+ - name : Install JAX, Numpyro, PyTorch
20+ shell : bash -l {0}
21+ run : |
22+ pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
23+ pip install --upgrade "jax[cuda12-local]"
24+ pip install numpyro
25+ python scripts/test-jax-install.py
26+ - name : Check nvidia Drivers
1527 shell : bash -l {0}
1628 run : nvidia-smi
1729 - name : Display Conda Environment Versions
Original file line number Diff line number Diff line change 1+ import jax
2+ import jax .numpy as jnp
3+
4+ devices = jax .devices ()
5+ print (f"The available devices are: { devices } " )
6+
7+ @jax .jit
8+ def matrix_multiply (a , b ):
9+ return jnp .dot (a , b )
10+
11+ # Example usage:
12+ key = jax .random .PRNGKey (0 )
13+ x = jax .random .normal (key , (1000 , 1000 ))
14+ y = jax .random .normal (key , (1000 , 1000 ))
15+ z = matrix_multiply (x , y )
16+
17+ # Now the function is JIT compiled and will likely run on GPU (if available)
18+ print (z )
19+
20+ devices = jax .devices ()
21+ print (f"The available devices are: { devices } " )
You can’t perform that action at this time.
0 commit comments