Skip to content

Commit ecf7fde

Browse files
Add B200 testing to continuous workflow
1 parent 97bbc37 commit ecf7fde

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

.github/workflows/pytest_cuda.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ jobs:
5454
runs-on: ${{ inputs.runner }}
5555
# TODO: Update to the generic ML ecosystem test containers when they are ready.
5656
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
57-
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
57+
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') ||
58+
(contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }}
5859
name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
5960

6061
env:

.github/workflows/wheel_tests_continuous.yml

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,18 +110,30 @@ jobs:
110110
fail-fast: false # don't cancel all jobs on failure
111111
matrix:
112112
# Python values need to match the matrix strategy in the artifact build jobs above
113-
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
113+
# See the exclusions below for which configurations are fully tested
114+
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"]
114115
python: ["3.10",]
115-
cuda: ["12.3", "12.1"]
116+
cuda: ["12.1","12.3","12.8"]
116117
enable-x64: [1, 0]
117118
exclude:
118-
# Run only a single configuration on H100 to save resources
119+
# L4 does not run on cuda 12.8 but tests other configs
120+
- runner: "linux-x86-g2-48-l4-4gpu"
121+
cuda: "12.8"
122+
# H100 runs only a single config, CUDA 12.3 Enable x64 1
123+
- runner: "linux-x86-a3-8g-h100-8gpu"
124+
cuda: "12.8"
119125
- runner: "linux-x86-a3-8g-h100-8gpu"
120-
python: "3.10"
121126
cuda: "12.1"
122127
- runner: "linux-x86-a3-8g-h100-8gpu"
123-
python: "3.10"
124-
enable-x64: 0
128+
enable-x64: "0"
129+
# B200 runs only a single config, CUDA 12.8 Enable x64 1
130+
- runner: "linux-x86-a4-224-b200-1gpu"
131+
enable-x64: "0"
132+
- runner: "linux-x86-a4-224-b200-1gpu"
133+
cuda: "12.1"
134+
- runner: "linux-x86-a4-224-b200-1gpu"
135+
cuda: "12.3"
136+
125137
name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
126138
with:
127139
runner: ${{ matrix.runner }}

0 commit comments

Comments
 (0)