File tree Expand file tree Collapse file tree 5 files changed +76
-18
lines changed Expand file tree Collapse file tree 5 files changed +76
-18
lines changed Original file line number Diff line number Diff line change 66 workflow_dispatch :
77jobs :
88 cache :
9- runs-on : quantecon-gpu
10- container :
11- image : ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
12- options : --gpus all
9+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
1310 steps :
1411 - uses : actions/checkout@v4
1512 with :
1613 ref : ${{ github.event.pull_request.head.sha }}
14+ - name : Setup Anaconda
15+ uses : conda-incubator/setup-miniconda@v3
16+ with :
17+ auto-update-conda : true
18+ auto-activate-base : true
19+ miniconda-version : ' latest'
20+ python-version : " 3.12"
21+ environment-file : environment.yml
22+ activate-environment : quantecon
23+ - name : Install JAX, Numpyro, PyTorch
24+ shell : bash -l {0}
25+ run : |
26+ pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
27+ pip install --upgrade "jax[cuda12-local]"
28+ pip install numpyro
29+ python scripts/test-jax-install.py
1730 - name : Check nvidia drivers
1831 shell : bash -l {0}
1932 run : |
Original file line number Diff line number Diff line change 11name : Build Project [using jupyter-book]
2- on : [pull_request]
2+ on :
3+ pull_request :
4+ workflow_dispatch :
35jobs :
46 preview :
5- runs-on : quantecon-gpu
6- container :
7- image : ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
8- options : --gpus all
7+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
98 steps :
109 - uses : actions/checkout@v4
1110 with :
1211 ref : ${{ github.event.pull_request.head.sha }}
13- # Check nvidia drivers
14- - name : nvidia Drivers
12+ - name : Setup Anaconda
13+ uses : conda-incubator/setup-miniconda@v3
14+ with :
15+ auto-update-conda : true
16+ auto-activate-base : true
17+ miniconda-version : ' latest'
18+ python-version : " 3.12"
19+ environment-file : environment.yml
20+ activate-environment : quantecon
21+ - name : Install JAX, Numpyro, PyTorch
22+ shell : bash -l {0}
23+ run : |
24+ pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
25+ pip install --upgrade "jax[cuda12-local]"
26+ pip install numpyro
27+ python scripts/test-jax-install.py
28+ - name : Check nvidia Drivers
1529 shell : bash -l {0}
1630 run : nvidia-smi
1731 - name : Display Conda Environment Versions
Original file line number Diff line number Diff line change @@ -2,7 +2,7 @@ name: Build Project on Google Collab (Execution)
22on : [pull_request]
33jobs :
44 execution-checks :
5- runs-on : quantecon- gpu
5+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=ubuntu24- gpu-x64/disk=large "
66 container :
77 image : docker://us-docker.pkg.dev/colab-images/public/runtime
88 options : --gpus all
Original file line number Diff line number Diff line change 66jobs :
77 publish :
88 if : github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
9- runs-on : quantecon-gpu
10- container :
11- image : ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
12- options : --gpus all
9+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
1310 steps :
1411 - name : Checkout
1512 uses : actions/checkout@v4
16- - name : Install Git (required to commit notebooks)
13+ - name : Setup Anaconda
14+ uses : conda-incubator/setup-miniconda@v3
15+ with :
16+ auto-update-conda : true
17+ auto-activate-base : true
18+ miniconda-version : ' latest'
19+ python-version : " 3.12"
20+ environment-file : environment.yml
21+ activate-environment : quantecon
22+ - name : Install JAX, Numpyro, PyTorch
1723 shell : bash -l {0}
18- run : apt-get install -y git
24+ run : |
25+ pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
26+ pip install --upgrade "jax[cuda12-local]"
27+ pip install numpyro
28+ python scripts/test-jax-install.py
1929 - name : Check nvidia drivers
2030 shell : bash -l {0}
2131 run : |
Original file line number Diff line number Diff line change 1+ import jax
2+ import jax .numpy as jnp
3+
4+ devices = jax .devices ()
5+ print (f"The available devices are: { devices } " )
6+
7+ @jax .jit
8+ def matrix_multiply (a , b ):
9+ return jnp .dot (a , b )
10+
11+ # Example usage:
12+ key = jax .random .PRNGKey (0 )
13+ x = jax .random .normal (key , (1000 , 1000 ))
14+ y = jax .random .normal (key , (1000 , 1000 ))
15+ z = matrix_multiply (x , y )
16+
17+ # Now the function is JIT compiled and will likely run on GPU (if available)
18+ print (z )
19+
20+ devices = jax .devices ()
21+ print (f"The available devices are: { devices } " )
You can’t perform that action at this time.
0 commit comments