Running TPU tests on linux-x86-ct6e-44-1tpu #21425

Open
wants to merge 65 commits into
base: master
1d7c685
added requirements-tensorflow-tpu.txt and tpu configuration in .kokoro
kharshith-k Jun 16, 2025
19b5e6b
updated .kokoro/github/ubuntu/tpu/build.sh with jax and torch backend…
kharshith-k Jun 16, 2025
d203ca3
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 18, 2025
f45e5d0
Changed the tpu CI config files path to .github from .kokoro
kharshith-k Jun 18, 2025
6771cc0
Added new job in .github/workflows/actions.yml to run TPU tests
kharshith-k Jun 18, 2025
87d36e7
fixed runs-on option in actions.yml for tpu_build job to run on self…
kharshith-k Jun 18, 2025
9901298
Added another runner in the actions TPU job
kharshith-k Jun 18, 2025
be97210
Update continuous.cfg
kharshith-k Jun 18, 2025
a1cd5c3
Update presubmit.cfg
kharshith-k Jun 18, 2025
c5e3a5c
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 23, 2025
f0ab676
Update actions.yml
kharshith-k Jun 23, 2025
09161d7
Developed Dockerfile for TPU build job in actions.yml
kharshith-k Jun 24, 2025
9a3948f
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 24, 2025
058fdff
Update actions.yml
kharshith-k Jun 24, 2025
d47e39e
Included few more runners in tpu_build job
kharshith-k Jun 26, 2025
a6a59d7
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 26, 2025
ba4f6ae
Using linux-x86-ct6e-44-1tpu
kharshith-k Jun 26, 2025
a5a3624
Modified requirements-common.txt and updated requirements-tensorflow-…
kharshith-k Jun 30, 2025
b9998af
Added Dtypes_TPU_tests.py and requirements-jax-tpu.txt
kharshith-k Jul 22, 2025
f68be97
Progress bar now handles `steps_per_execution`. (#21422)
hertschuh Jun 26, 2025
1018abf
Fix symbolic call of `logsumexp` with int axis. (#21428)
hertschuh Jun 27, 2025
0da77e4
Only allow deserialization of `KerasSaveable`s by module and name. (#…
hertschuh Jun 29, 2025
cb639c5
commented tensorflow deps
kharshith-k Jul 2, 2025
c0d1743
Added log of dtypes_test_tpu.py and the test script for the same
kharshith-k Jul 2, 2025
306e6e7
modified dtypes_test_tpu.py as per pre-commit standards
kharshith-k Jul 2, 2025
4e584fc
Added TPU initialization and teardown functionalities in conftest.py, …
kharshith-k Jul 3, 2025
bb09e95
Added dtypes_test_TPU.py and dtypes_new_test.py, modified conftest.py
kharshith-k Jul 9, 2025
8a63d09
Added Dockerfile and tests list command
kharshith-k Jul 23, 2025
4651454
Updated Dockerfile
kharshith-k Jul 28, 2025
40af241
Restored Dockerfile to previous changes
kharshith-k Jul 28, 2025
64420d5
updated actions.yml file to install and configure docker engine on se…
kharshith-k Jul 28, 2025
da84de5
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jul 28, 2025
d69277d
updated actions.yml file to include container option
kharshith-k Jul 28, 2025
1c307fc
updated actions.yml file to include container option without volume b…
kharshith-k Jul 28, 2025
693886b
updated actions.yml file to change TPU
kharshith-k Jul 28, 2025
e74b851
Updated container path in build-and-test-on-tpu job
kharshith-k Jul 29, 2025
d31b3c4
separated TPU workflow from actions.yml
kharshith-k Jul 29, 2025
a70d19e
updated trigger condition for TPU tests workflow
kharshith-k Jul 29, 2025
5f5b609
updated container usage configuration for TPU tests workflow
kharshith-k Jul 29, 2025
72e729f
updated env vars for TPU tests workflow
kharshith-k Jul 29, 2025
e129299
updated env vars parsing syntax in TPU tests workflow
kharshith-k Jul 29, 2025
3fe5b57
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
10df307
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
dd21e09
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
328628f
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
01f0c17
updated image name in TPU tests workflow
kharshith-k Jul 29, 2025
3e41c37
updated image name with generic ubuntu image
kharshith-k Jul 29, 2025
5e55c2c
updated tpu-tests to use ghcr
kharshith-k Jul 29, 2025
ea9ff88
updated tpu-tests to store built image as local tar
kharshith-k Jul 29, 2025
6d92aa9
updated image name from ubuntu:22.04 to docker:24.0-cli in tpu tests …
kharshith-k Jul 29, 2025
3c75bf8
updated image name from docker:24.0-cli to ubuntu:22.04 in tpu tests…
kharshith-k Jul 29, 2025
1589a75
added volume mount from host in load-and-test-job
kharshith-k Jul 29, 2025
36bd682
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jul 29, 2025
04112cf
Reverted tpu-tests.yml to version using ghcr.io for image storage
kharshith-k Jul 29, 2025
87d7ad8
Removed custom dtypes_test files for TPU testing and restored origina…
kharshith-k Aug 12, 2025
6cb097c
Updated tpu-tests.yml to pull image from GCP artifact registry
kharshith-k Aug 12, 2025
4829f1b
Resolved conflicts in actions.yml
kharshith-k Aug 12, 2025
a2eb306
Added a workflow to check service accounts associated with self hoste…
kharshith-k Aug 12, 2025
23579c4
Made find_sa.yml specific to linux-x86-ct6e-44-1tpu
kharshith-k Aug 12, 2025
dac6433
Added container tag to find_sa.yml
kharshith-k Aug 12, 2025
05461c1
Checking SA for linux-x86-ct5lp-112-4tpu
kharshith-k Aug 12, 2025
078dcee
Checking SA for linux-x86-ct6e-44-1tpu-nxgm7-runner-vb87c
kharshith-k Aug 12, 2025
016c68d
Using SA for auth in tpu-tests
kharshith-k Aug 12, 2025
02657f0
Updated SA with container tag for auth in tpu-tests
kharshith-k Aug 12, 2025
7167952
Added docker socket mount test
kharshith-k Aug 12, 2025
4 changes: 2 additions & 2 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ jobs:
         run: |
           pip install -r requirements.txt --progress-bar off --upgrade
           if [ "${{ matrix.nnx_enabled }}" == "true" ]; then
-            pip install --upgrade git+https://github.com/google/flax.git
+            pip install --upgrade "flax>=0.11.1"
           fi
           pip uninstall -y keras keras-nightly
           pip install -e "." --progress-bar off --upgrade
@@ -147,4 +147,4 @@ jobs:
           pip uninstall -y keras keras-nightly
           pip install -e "." --progress-bar off --upgrade
       - name: Run pre-commit
-        run: pre-commit run --all-files --hook-stage manual
+        run: pre-commit run --all-files --hook-stage manual
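
A note on the `flax>=0.11.1` specifier in the first hunk: inside a `run:` step the line is interpreted by a shell, where an unquoted `>` starts an output redirection. The sketch below approximates shell tokenization with Python's `shlex` (using `punctuation_chars=True`); it is an illustration of why the specifier should be quoted, not a full shell parser, and `shell_tokens` is a hypothetical helper name.

```python
import shlex


def shell_tokens(command: str) -> list[str]:
    """Tokenize a command roughly the way a POSIX shell would.

    punctuation_chars=True makes shlex emit operators such as '>' and '|'
    as separate tokens instead of treating them as part of a word.
    """
    lexer = shlex.shlex(command, posix=True, punctuation_chars=True)
    return list(lexer)


# Unquoted: '>' is split out as an operator, so a shell would install plain
# "flax" and redirect pip's output to a file named "=0.11.1".
print(shell_tokens("pip install --upgrade flax>=0.11.1"))

# Quoted: the whole specifier survives as a single argument.
print(shell_tokens('pip install --upgrade "flax>=0.11.1"'))
```

In the unquoted case the version constraint never reaches pip, which is why quoting the requirement is the safer form in workflow steps.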
24 changes: 24 additions & 0 deletions .github/workflows/docker-socket-mount-test.yml
@@ -0,0 +1,24 @@
name: Test Docker Socket Mount

on:
  push:
    branches: [ master ]
  pull_request:

jobs:
  socket-mount-test:
    name: Docker Socket Mount Test
    runs-on: linux-x86-ct6e-44-1tpu

    # Use a minimal public container. The key is the 'options' block.
    container:
      image: ubuntu:latest
      options: -v /var/run/docker.sock:/var/run/docker.sock

    steps:
      - name: Check if docker.sock is mounted
        run: |
          echo "Attempting to list the Docker socket inside the container..."
          ls -l /var/run/docker.sock
          echo "If the line above shows the file details, the mount was successful."
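
The `ls -l` step above only shows that a path exists; it does not confirm that the path is a socket. A slightly stricter check could look like the following minimal sketch, which uses only the standard library. The helper name `is_unix_socket` is hypothetical, and the check says nothing about whether a Docker daemon is actually listening behind the socket.

```python
import os
import stat


def is_unix_socket(path: str) -> bool:
    """Return True only if `path` exists and is a Unix domain socket."""
    try:
        mode = os.stat(path).st_mode
    except OSError:
        # Missing path, permission problem, etc.
        return False
    return stat.S_ISSOCK(mode)


if __name__ == "__main__":
    sock_path = "/var/run/docker.sock"
    print(f"{sock_path} is a socket: {is_unix_socket(sock_path)}")
```

A regular file or an empty directory accidentally created at the mount point would pass the `ls -l` check but fail this one.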

31 changes: 31 additions & 0 deletions .github/workflows/find_sa.yml
@@ -0,0 +1,31 @@
name: Find Runner Service Account

# Trigger on push to master, pull requests, or new releases
on:
  push:
    branches: [ master ]
  pull_request:
  release:
    types: [created]

jobs:
  discover_service_account:
    name: Discover Service Account
    # The runner name is now hardcoded since we removed the manual input
    runs-on: linux-x86-ct6e-44-1tpu-nxgm7-runner-vb87c

    # Add a container block to satisfy the runner's security policy.
    # We use a standard Google Cloud SDK image which has gcloud pre-installed.
    container:
      image: gcr.io/google.com/cloudsdktool/cloud-sdk:slim
      # Using host networking makes it easier for the container to access
      # the host VM's metadata server to find the service account.
      options: --network host

    steps:
      - name: Print the configured GCP Service Account
        run: |
          echo "Querying the GCP metadata server from within a container..."
          echo "The active account shown below is the one used by the runner's host VM."
          gcloud auth list
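
For reference, the same information can be read directly from the Compute Engine metadata server without `gcloud`. The sketch below assumes the runner is a GCE VM with a default service account attached and that the container can reach `metadata.google.internal` (for example through host networking, as configured above). The function names are hypothetical; the endpoint and the mandatory `Metadata-Flavor: Google` header are the documented ones.

```python
import urllib.request

# Documented GCE metadata endpoint for the default service account's email.
METADATA_URL = (
    "http://metadata.google.internal/computeMetadata/v1/"
    "instance/service-accounts/default/email"
)


def build_metadata_request(url: str = METADATA_URL) -> urllib.request.Request:
    """Build the GET request; the metadata server rejects calls without this header."""
    return urllib.request.Request(url, headers={"Metadata-Flavor": "Google"})


def fetch_service_account_email(timeout: float = 5.0) -> str:
    """Query the metadata server. Only works where the server is reachable."""
    with urllib.request.urlopen(build_metadata_request(), timeout=timeout) as resp:
        return resp.read().decode("utf-8").strip()
```

Calling `fetch_service_account_email()` outside Google Cloud raises a `urllib.error.URLError`, so a caller would normally wrap it in error handling.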

58 changes: 58 additions & 0 deletions .github/workflows/tpu-tests.yml
@@ -0,0 +1,58 @@
name: TPU Tests

on:
  push:
    branches: [ master ]
  pull_request:
  release:
    types: [created]

permissions:
  contents: read

jobs:
  test-on-tpu-runner:
    name: Test on TPU Runner
    runs-on: linux-x86-ct6e-44-1tpu

    # STAGE 1: Use a public Google Cloud SDK image for the main job container.
    # This satisfies the runner's "container required" rule without needing prior authentication.
    container:
      image: gcr.io/google.com/cloudsdktool/cloud-sdk:slim
      # Mount the host's Docker socket into this container. This allows steps inside
      # to use the host's Docker daemon, which is critical.
      # Privileged and network host are needed for the final test container to access TPU hardware.
      options: --privileged --network host -v /var/run/docker.sock:/var/run/docker.sock

    steps:
      - name: Authenticate to Google Cloud
        uses: google-github-actions/auth@v2
        with:
          # This authenticates the environment inside the gcloud-sdk container.
          credentials_json: ${{ secrets.GCP_CREDENTIALS }}

      - name: Checkout Repository
        uses: actions/checkout@v4

      - name: Run Tests Inside Pre-Built Container
        # STAGE 2: Now that we are authenticated, we can use the Docker CLI
        # (talking to the host's daemon via the mounted socket) to run our private container.
        run: |
          echo "Pulling the private test container from GAR..."
          docker pull us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest

          echo "Running the test container..."
          docker run --rm --privileged --network host \
            -v "${{ github.workspace }}":"/github/workspace" \
            --workdir "/github/workspace" \
            us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest \
            bash -c '
              echo "Successfully running inside the custom container from GAR!"
              echo "Current working directory:"
              pwd
              echo "Contents of current directory:"
              ls -la
              echo "Verifying JAX installation..."
              python3 -c "import jax; print(f\"JAX backend: {jax.default_backend()}\"); print(f\"JAX devices: {jax.devices()}\")"
            '
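
The `docker run` invocation in stage 2 can be easier to review when its flags are assembled as an argument list. The following is a minimal sketch of that idea; `build_docker_run_cmd` and the image name in the usage line are hypothetical, and the flags simply mirror the ones used in the workflow above.

```python
import shlex


def build_docker_run_cmd(image, workspace, container_dir="/github/workspace",
                         inner_cmd=("bash",)):
    """Assemble a `docker run` argument list mirroring the workflow step."""
    return [
        "docker", "run", "--rm",
        "--privileged",            # device access for the test container
        "--network", "host",       # share the host network namespace
        "-v", f"{workspace}:{container_dir}",
        "--workdir", container_dir,
        image,
        *inner_cmd,
    ]


cmd = build_docker_run_cmd(
    "example.registry/keras-jax-tpu:latest",  # placeholder image name
    "/home/runner/work/keras",                # placeholder workspace path
    inner_cmd=("python3", "-c", "import jax; print(jax.devices())"),
)
print(shlex.join(cmd))
```

One design point to keep in mind: because the Docker CLI talks to the host's daemon through the mounted socket, the source side of `-v` is resolved on the host filesystem, not inside the job container. If the workspace path seen by the job container differs from the host path, the bind mount may appear empty in the test container.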

74 changes: 74 additions & 0 deletions .github/workflows/tpu/build.sh
@@ -0,0 +1,74 @@
set -e
set -x

cd "${KOKORO_ROOT}/"

sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1

PYTHON_BINARY="/usr/bin/python3.10"

"${PYTHON_BINARY}" -m venv venv
source venv/bin/activate
# Check the python version
python --version
python3 --version

cd "src/github/keras"
pip install -U pip setuptools
# psutil is used by background log reader
pip install -U psutil

if [ "$KERAS_BACKEND" == "tensorflow" ]
then
   echo "TensorFlow backend detected."
   pip install -r requirements-tensorflow-tpu.txt --progress-bar off --timeout 1000
   pip uninstall -y keras keras-nightly
   echo "Check that TensorFlow uses TPU"
   python3 -c 'import tensorflow as tf;print(tf.__version__);print(tf.config.list_physical_devices("TPU"))'
   # Raise error if TPU is not detected.
   python3 -c 'import tensorflow as tf;assert len(tf.config.list_physical_devices("TPU")) > 0'

   # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
   pytest keras --ignore keras/src/applications \
                --ignore keras/src/layers/merging/merging_test.py \
                --cov=keras \
                --cov-config=pyproject.toml
fi

if [ "$KERAS_BACKEND" == "jax" ]
then
   echo "JAX backend detected."
   pip install -r requirements-jax-tpu.txt --progress-bar off --timeout 1000
   pip uninstall -y keras keras-nightly
   python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())'
   # Raise error if TPU is not detected.
   python3 -c 'import jax;assert jax.default_backend().lower() == "tpu"'

   # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
   # TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted
   # keras/backend/jax/distribution_lib_test.py is configured for CPU test for now.
   pytest keras --ignore keras/src/applications \
                --ignore keras/src/layers/merging/merging_test.py \
                --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \
                --ignore keras/src/backend/jax/distribution_lib_test.py \
                --ignore keras/src/distribution/distribution_lib_test.py \
                --cov=keras \
                --cov-config=pyproject.toml

   pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml
fi

if [ "$KERAS_BACKEND" == "torch" ]
then
   echo "PyTorch backend detected."
   pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000
   pip uninstall -y keras keras-nightly
   python3 -c 'import torch;print(torch.__version__);print(torch.cuda.is_available())'
   # Raise error if GPU is not detected.
   python3 -c 'import torch;assert torch.cuda.is_available()'

   pytest keras --ignore keras/src/applications \
                --cov=keras \
                --cov-config=pyproject.toml

fi
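
The three backend branches in `build.sh` differ mainly in which test paths they pass to `--ignore`. As a review aid, the per-backend pytest invocation can be expressed as a small lookup table. This is a hedged sketch: `pytest_args`, `COMMON_IGNORES`, and `EXTRA_IGNORES` are hypothetical names, and the paths are copied from the script above rather than taken from any official configuration.

```python
# Paths ignored for every backend.
COMMON_IGNORES = ["keras/src/applications"]

# Additional per-backend ignores, copied from build.sh.
EXTRA_IGNORES = {
    "tensorflow": ["keras/src/layers/merging/merging_test.py"],
    "jax": [
        "keras/src/layers/merging/merging_test.py",
        "keras/src/trainers/data_adapters/py_dataset_adapter_test.py",
        "keras/src/backend/jax/distribution_lib_test.py",
        "keras/src/distribution/distribution_lib_test.py",
    ],
    "torch": [],
}


def pytest_args(backend: str) -> list[str]:
    """Return the main pytest command line for the given KERAS_BACKEND value."""
    if backend not in EXTRA_IGNORES:
        raise ValueError(f"Unknown backend: {backend!r}")
    args = ["pytest", "keras"]
    for path in COMMON_IGNORES + EXTRA_IGNORES[backend]:
        args += ["--ignore", path]
    args += ["--cov=keras", "--cov-config=pyproject.toml"]
    return args


print(" ".join(pytest_args("jax")))
```

Keeping the ignore lists in one table makes it easier to see which exclusions are shared and which are backend-specific. The JAX branch's separate second run of `keras/src/distribution/distribution_lib_test.py` is not modeled here.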
16 changes: 16 additions & 0 deletions .github/workflows/tpu/tensorflow/continuous.cfg
@@ -0,0 +1,16 @@
build_file: "keras/.github/workflows/tpu/build.sh"

action {
  define_artifacts {
    regex: "**/sponge_log.log"
    regex: "**/sponge_log.xml"
  }
}

env_vars: {
  key: "KERAS_BACKEND"
  value: "tensorflow"
}

# Set timeout to 60 mins from default 180 mins
timeout_mins: 60
16 changes: 16 additions & 0 deletions .github/workflows/tpu/tensorflow/presubmit.cfg
@@ -0,0 +1,16 @@
build_file: "keras/.github/workflows/tpu/build.sh"

action {
  define_artifacts {
    regex: "**/sponge_log.log"
    regex: "**/sponge_log.xml"
  }
}

env_vars: {
  key: "KERAS_BACKEND"
  value: "tensorflow"
}

# Set timeout to 60 mins from default 180 mins
timeout_mins: 60
28 changes: 28 additions & 0 deletions Dockerfile
@@ -0,0 +1,28 @@
FROM --platform=linux/amd64 python:3.10-slim

ENV KERAS_HOME=/github/workspace/.github/workflows/config/jax \
    KERAS_BACKEND=jax

RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    sudo \
    && rm -rf /var/lib/apt/lists/*

# Copy the entire codebase into the container
COPY . /github/workspace
WORKDIR /github/workspace

# Upgrade pip/setuptools, install psutil and the JAX TPU requirements, and remove any preinstalled Keras (tests are not run at build time)
# RUN cd ./keras/src/github/keras && \
RUN pip install --no-cache-dir -U pip setuptools && \
    pip install --no-cache-dir -U psutil && \
    pip install --no-cache-dir -r requirements-jax-tpu.txt && \
    pip uninstall -y keras keras-nightly
    # python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())' && \
    # python3 -c 'import jax;assert jax.default_backend().lower() == "tpu"' && \
    # pytest keras --ignore keras/src/applications \
    #   --ignore keras/src/layers/merging/merging_test.py \
    #   --cov=keras \
    #   --cov-config=pyproject.toml

CMD ["/bin/bash"]
70 changes: 70 additions & 0 deletions conftest.py
@@ -22,6 +22,9 @@ def pytest_configure(config):
         "markers",
         "requires_trainable_backend: mark test for trainable backend only",
     )
+    config.addinivalue_line(
+        "markers", "requires_tpu: mark test to run only on TPU"
+    )
 
 
 def pytest_collection_modifyitems(config, items):
@@ -59,3 +62,70 @@ def pytest_collection_modifyitems(config, items):
 
 def skip_if_backend(given_backend, reason):
     return pytest.mark.skipif(backend() == given_backend, reason=reason)
+
+
+
+
+def _cleanup_tpu_state():
+    import tensorflow as tf
+
+    try:
+        tf.config.experimental_disconnect_from_cluster()
+    except Exception:
+        pass
+
+    try:
+        tf.config.experimental_reset_memory_stats("TPU_SYSTEM")
+    except Exception:
+        pass
+
+
+@pytest.fixture(scope="session")
+def tpu_strategy_fixture():
+    import time
+
+    import tensorflow as tf
+
+    os.environ["TPU_NAME"] = "harshith-tf-4"
+    os.environ["JAX_PLATFORMS"] = ""
+    max_retries = int(os.environ.get("TPU_MAX_RETRIES", "3"))
+    base_delay = float(os.environ.get("TPU_BASE_DELAY", "2.0"))
+    tpu_available = False
+    strategy = None
+
+    for attempt in range(max_retries):
+        try:
+            print(f"TPU initialization attempt {attempt + 1}/{max_retries}")
+            _cleanup_tpu_state()
+            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
+            tf.config.experimental_connect_to_cluster(resolver)
+            tf.tpu.experimental.initialize_tpu_system(resolver)
+            strategy = tf.distribute.TPUStrategy(resolver)
+            tpu_available = True
+            print("✓ TPU initialization successful!")
+            break
+        except Exception as e:
+            print(f"✗ TPU initialization attempt {attempt + 1} failed: {e}")
+            if attempt < max_retries - 1:
+                delay = base_delay * (2**attempt) + (attempt * 0.5)
+                print(f"Retrying in {delay:.1f} seconds...")
+                time.sleep(delay)
+                _cleanup_tpu_state()
+            else:
+                print("All TPU initialization attempts failed.")
+
+    if not tpu_available:
+        pytest.skip("TPU not available")
+
+    yield strategy
+
+    # Teardown
+    _cleanup_tpu_state()
+
+
+@pytest.fixture(autouse=True)
+def tpu(request):
+    marker = request.node.get_closest_marker("requires_tpu")
+    if marker:
+        strategy = request.getfixturevalue("tpu_strategy_fixture")
+        request.node.cls.tpu_strategy = strategy
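
To make the retry timing in `tpu_strategy_fixture` concrete, the delay formula `base_delay * (2**attempt) + (attempt * 0.5)` can be evaluated on its own. The sketch below is illustrative only: `backoff_delays` is a hypothetical helper that is not part of `conftest.py`, and it assumes, as in the fixture, that no sleep happens after the final failed attempt.

```python
def backoff_delays(max_retries: int = 3, base_delay: float = 2.0) -> list[float]:
    """Delays (in seconds) slept between attempts, using the fixture's formula."""
    return [
        base_delay * (2**attempt) + attempt * 0.5
        for attempt in range(max_retries - 1)  # no sleep after the last attempt
    ]


print(backoff_delays())        # [2.0, 4.5]
print(backoff_delays(5, 1.0))  # [1.0, 2.5, 5.0, 9.5]
```

With the default `TPU_MAX_RETRIES=3` and `TPU_BASE_DELAY=2.0`, the fixture therefore waits at most 6.5 seconds in total before skipping the TPU tests, not counting the time spent inside each initialization attempt.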