TESTING ONLY (TPU test) #21528

Draft: wants to merge 56 commits into base: master

Commits (56)
1d7c685
added requirements-tensorflow-tpu.txt and tpu configuration in .kokoro
kharshith-k Jun 16, 2025
19b5e6b
updated .kokoro/github/ubuntu/tpu/build.sh with jax and torch backend…
kharshith-k Jun 16, 2025
d203ca3
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 18, 2025
f45e5d0
Changed the tpu CI config files path to .github from .kokoro
kharshith-k Jun 18, 2025
6771cc0
Added new job in .github/workflows/actions.yml to run TPU tests
kharshith-k Jun 18, 2025
87d36e7
fixed runs-on option in acvtions.yml for tpu_build job to run on self…
kharshith-k Jun 18, 2025
9901298
Added another runner in the actions TPU job
kharshith-k Jun 18, 2025
be97210
Update continuous.cfg
kharshith-k Jun 18, 2025
a1cd5c3
Update presubmit.cfg
kharshith-k Jun 18, 2025
c5e3a5c
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 23, 2025
f0ab676
Update actions.yml
kharshith-k Jun 23, 2025
09161d7
Developed Dockerfile for TPU build job in actions.yml
kharshith-k Jun 24, 2025
9a3948f
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 24, 2025
058fdff
Update actions.yml
kharshith-k Jun 24, 2025
d47e39e
Included few more runners in tpu_build job
kharshith-k Jun 26, 2025
a6a59d7
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 26, 2025
ba4f6ae
Using linux-x86-ct6e-44-1tpu
kharshith-k Jun 26, 2025
a5a3624
Modified requirement-commmon.txt and updated requirements-tensorflow-…
kharshith-k Jun 30, 2025
b9998af
Added Dtypes_TPU_tests.py and requirements-jax-tpu.txt
kharshith-k Jul 22, 2025
f68be97
Progress bar now handles `steps_per_execution`. (#21422)
hertschuh Jun 26, 2025
1018abf
Fix symbolic call of `logsumexp` with int axis. (#21428)
hertschuh Jun 27, 2025
0da77e4
Only allow deserialization of `KerasSaveable`s by module and name. (#…
hertschuh Jun 29, 2025
cb639c5
commented tensorflow deps
kharshith-k Jul 2, 2025
c0d1743
Added log of dtypes_test_tpu.py and the test script for the same
kharshith-k Jul 2, 2025
306e6e7
modified dtypes_test_tpu.py as per pre-commit standards
kharshith-k Jul 2, 2025
4e584fc
Added TPU initiaization and teardown functionalities in conftest.py, …
kharshith-k Jul 3, 2025
bb09e95
Added dtypes_test_TPU.py and dtypes_new_test.py, modified conftest.py
kharshith-k Jul 9, 2025
8a63d09
Added Dcokerfile and tests list command
kharshith-k Jul 23, 2025
4651454
Updated Dockerfile
kharshith-k Jul 28, 2025
40af241
Restored Dockerfile to previous changes
kharshith-k Jul 28, 2025
64420d5
updated actions.yml file to install and configure docker engine on se…
kharshith-k Jul 28, 2025
da84de5
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jul 28, 2025
d69277d
updated actions.yml file to include container option
kharshith-k Jul 28, 2025
1c307fc
updated actions.yml file to include container option without volume b…
kharshith-k Jul 28, 2025
693886b
updated actions.yml file to change TPU
kharshith-k Jul 28, 2025
e74b851
Updated container path in build-and-test-on-tpu job
kharshith-k Jul 29, 2025
d31b3c4
seperated TPU workflow from actions.yml
kharshith-k Jul 29, 2025
a70d19e
updated trigger condition for TPU tests workflow
kharshith-k Jul 29, 2025
5f5b609
updated container usage configuration for TPU tests workflow
kharshith-k Jul 29, 2025
72e729f
updated env vars for TPU tests workflow
kharshith-k Jul 29, 2025
e129299
updated env vars parsing syntax in TPU tests workflow
kharshith-k Jul 29, 2025
3fe5b57
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
10df307
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
dd21e09
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
328628f
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
01f0c17
updated image name in TPU tests workflow
kharshith-k Jul 29, 2025
3e41c37
updated image name with generic ubuntu image
kharshith-k Jul 29, 2025
5e55c2c
updated tpu-tests to use ghcr
kharshith-k Jul 29, 2025
ea9ff88
updated tpu-tests to store built image as local tar
kharshith-k Jul 29, 2025
6d92aa9
updated image name from ubuntu:22.04 to docker:24.0-cli in tpu tests …
kharshith-k Jul 29, 2025
3c75bf8
updated image name from docker:24.0-cli to ubuntu:22.04 in tpu tests…
kharshith-k Jul 29, 2025
1589a75
added volume mount from host in load-and-test-job
kharshith-k Jul 29, 2025
36bd682
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jul 29, 2025
04112cf
Reverted tpu-tests.yml to version using ghcr.io for image storage
kharshith-k Jul 29, 2025
2acd44c
[OpenVINO backend] fix openvino model exported names to match keras n…
Mohamed-Ashraf273 Jul 29, 2025
3899730
update access token
sachinprasadhs Aug 6, 2025
93 changes: 93 additions & 0 deletions .github/workflows/actions.yml
@@ -12,8 +12,101 @@ on:

permissions:
  contents: read
  # id-token: write

# env:
#   PYTHON: ${{ matrix.python-version }}
#   KERAS_HOME: .github/workflows/config/${{ matrix.backend }}
#   KERAS_BACKEND: jax
#   PROJECT_ID: gtech-rmi-dev # Replace with your GCP project ID
#   GAR_LOCATION: us-central1 # Replace with your Artifact Registry location (e.g., us-central1)
#   IMAGE_REPO: keras-docker-images
#   IMAGE_NAME: keras-jax-tpu-amd64:latest # Name of your Docker image
#   TPU_VM_NAME: kharshith-jax-tpu # Replace with your TPU VM instance name
#   TPU_VM_ZONE: us-central1-b # Replace with your TPU VM zone

jobs:
  # build-and-test-on-tpu:
  #   strategy:
  #     fail-fast: false
  #     matrix:
  #       python-version: ['3.10']
  #       backend: [jax]
  #   name: Run TPU tests
  #   runs-on:
  #     # - keras-jax-tpu-runner
  #     # - linux-x86-ct5lp-112-4tpu
  #     # - linux-x86-ct5lp-112-4tpu-fvn6n-runner-6kb8n
  #     - linux-x86-ct6e-44-1tpu
  #     # - linux-x86-ct6e-44-1tpu-4khbn-runner-x4st4
  #     # - linux-x86-ct6e-44-1tpu-4khbn-runner-45nmc

  #   container: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest

  #   # container:
  #   #   image: docker:latest # Provides the Docker CLI within the job container
  #   #   volumes:
  #   #     - /var/run/docker.sock:/var/run/docker.sock # Mounts host's Docker socket for control
  #   #   options: --privileged

  #   steps:
  #     - name: Checkout Repository
  #       uses: actions/checkout@v4

  #     - name: Set up Docker BuildX
  #       uses: docker/setup-buildx-action@v3

  #     - name: Authenticate to Google Cloud (Workload Identity Federation)
  #       id: 'auth'
  #       uses: 'google-github-actions/auth@v2'
  #       with:
  #         # Replace with your Workload Identity Federation provider details.
  #         # This service account needs 'Artifact Registry Writer' role.
  #         workload_identity_provider: 'projects/YOUR_PROJECT_NUMBER/locations/global/workloadIdentityPools/YOUR_POOL_ID/providers/YOUR_PROVIDER_ID'
  #         service_account: 'your-github-actions-sa@${{ env.PROJECT_ID }}.iam.gserviceaccount.com'

  #     - name: Configure Docker to use Google Artifact Registry
  #       run: gcloud auth configure-docker ${{ env.GAR_LOCATION }}-docker.pkg.dev

  #     - name: Build Docker Image
  #       run: |
  #         IMAGE_TAG="${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:${{ github.sha }}"
  #         echo "Building Docker image: $IMAGE_TAG"
  #         docker build \
  #           --platform=linux/amd64 \
  #           -f .github/workflows/tpu/Dockerfile \
  #           -t "$IMAGE_TAG" \
  #           -t "${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:latest" \
  #           .
  #         echo "Built Docker image: $IMAGE_TAG"
  #         echo "LOCAL_TEST_IMAGE_TAG=$IMAGE_TAG" >> $GITHUB_ENV # Store for immediate use in run step

  #     - name: Push Docker Image to Artifact Registry
  #       run: |
  #         echo "Pushing Docker image to Artifact Registry: ${{ env.LOCAL_TEST_IMAGE_TAG }}"
  #         docker push "${{ env.LOCAL_TEST_IMAGE_TAG }}"
  #         docker push "${{ env.GAR_LOCATION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/${{ env.IMAGE_REPO }}/${{ env.IMAGE_NAME_BASE }}:latest"
  #         echo "Pushed Docker image."

  #     - name: Run Docker container and execute tests on TPU
  #       run: |
  #         echo "Running Docker container with TPU access and executing tests..."
  #         # PYTHON is pinned to a specific version here; it could instead be derived from the matrix.
  #         docker run --rm \
  #           --privileged \
  #           --network host \
  #           -e PYTHON=3.10 \
  #           -e KERAS_HOME=.github/workflows/config/jax \
  #           -e KERAS_BACKEND=jax \
  #           ${{ env.LOCAL_TEST_IMAGE_TAG }} \
  #           /bin/bash -c ' \
  #             echo "Verifying JAX TPU backend inside container..." && \
  #             python3 -c "import jax; print(\"JAX Version:\", jax.__version__); print(\"Default Backend:\", jax.default_backend()); assert jax.default_backend().lower() == \"tpu\", \"TPU backend not found or not default\"; print(\"TPU verification successful!\")" \
  #             # Add your actual pytest command here. Ensure pytest is installed inside your Docker image.
  #             # && pytest keras --ignore keras/src/applications --ignore keras/src/layers/merging/merging_test.py --cov=keras --cov-config=pyproject.toml
  #           '
  #         echo "Docker container finished running tests."


  build:
    strategy:
      fail-fast: false
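The commented-out job assembles Artifact Registry image references from several `env` values. As a rough illustration of that composition, here is a minimal Python sketch. `artifact_registry_image` is a hypothetical helper name, not part of this PR, and it assumes the image name is passed without a tag.

```python
def artifact_registry_image(location: str, project: str, repo: str, name: str, tag: str) -> str:
    """Compose `<location>-docker.pkg.dev/<project>/<repo>/<name>:<tag>`."""
    parts = {"location": location, "project": project, "repo": repo, "name": name, "tag": tag}
    for label, value in parts.items():
        # An unset workflow variable expands to an empty string, so reject it early.
        if not value:
            raise ValueError(f"{label} must not be empty")
    if ":" in name:
        raise ValueError("name must not already carry a tag")
    return f"{location}-docker.pkg.dev/{project}/{repo}/{name}:{tag}"


# Values taken from the workflow files in this PR.
image = artifact_registry_image(
    "us-central1", "gtech-rmi-dev", "keras-docker-images", "keras-jax-tpu-amd64", "latest"
)
print(image)
```

The empty-value guard is the useful part of the sketch: the build step reads `env.IMAGE_NAME_BASE`, while the commented-out `env` block only defines `IMAGE_NAME`, and an undefined `env` value in a workflow expression evaluates to an empty string.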
79 changes: 79 additions & 0 deletions .github/workflows/tpu-tests.yml
@@ -0,0 +1,79 @@

name: TPU Tests

on:
  push:
    branches: [ master ]
  pull_request:
  release:
    types: [created]

permissions:
  contents: read # Only read permission is needed for checkout
  packages: write

# Define base environment variables at the workflow level
# These can still be used inside steps, just not for the container image definition
env:
  PROJECT_ID: gtech-rmi-dev
  GAR_LOCATION: us-central1
  GAR_REPO: keras-docker-images
  IMAGE_NAME: keras-jax-tpu-amd64
  IMAGE_TAG: latest

jobs:

  build-and-push:
    name: Build and Push to GHCR
    runs-on: ubuntu-latest # This job doesn't need the special TPU runner

    steps:
      - name: Checkout Repository
        uses: actions/checkout@v4

      - name: Log in to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          # Authenticates with the GHCR_PAT repository secret rather than the automatically created GITHUB_TOKEN.
          username: ${{ github.actor }}
          password: ${{ secrets.GHCR_PAT }}

      - name: Build and Push Docker Image
        uses: docker/build-push-action@v6
        with:
          context: .
          # Push the image to ghcr.io
          push: true
          # Create a unique tag using the commit SHA for this specific build
          tags: ghcr.io/${{ github.repository }}:${{ github.sha }}


  test-in-container:
    name: Test in Custom Container
    # This job must run after the build-and-push job is complete
    needs: build-and-push

    # Use the required TPU runner
    runs-on: linux-x86-ct6e-44-1tpu

    # CRITICAL: Use the container image we just pushed in the previous job.
    # This satisfies the runner's requirement for a container to be specified.
    container:
      image: ghcr.io/${{ github.repository }}:${{ github.sha }}
      options: --privileged --network host

    steps:
      - name: Checkout Repository
        uses: actions/checkout@v4
        # We need the code available inside the container's workspace to run tests.

      - name: Run Verification and Tests
        run: |
          echo "Successfully running inside the custom container from GHCR!"
          echo "Current working directory:"
          pwd
          echo "Contents of current directory:"
          ls -la
          echo "Verifying JAX installation..."
          python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()}')"
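The final step only prints the JAX backend and devices, so the job passes even when JAX falls back to CPU. A minimal sketch of a stricter check, written as a standalone Python script, might look like the following. The helper names `backend_matches`, `detect_jax_backend` and `describe` are hypothetical and not part of this PR; the only JAX call used is `jax.default_backend()`, which the workflow already relies on.

```python
from typing import Optional


def backend_matches(actual: str, expected: str = "tpu") -> bool:
    """Case-insensitive comparison, mirroring `jax.default_backend().lower() == "tpu"`."""
    return actual.strip().lower() == expected.strip().lower()


def detect_jax_backend() -> Optional[str]:
    """Return JAX's default backend name, or None when JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


def describe(backend: Optional[str], expected: str = "tpu") -> str:
    """Build a one-line report suitable for CI logs."""
    if backend is None:
        return "JAX is not installed"
    status = "OK" if backend_matches(backend, expected) else "MISMATCH"
    return f"{status}: default backend is {backend!r}, expected {expected!r}"


if __name__ == "__main__":
    print(describe(detect_jax_backend()))
```

To make the job fail on a mismatch, the script could exit with a non-zero status whenever the report does not start with `OK`.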
74 changes: 74 additions & 0 deletions .github/workflows/tpu/build.sh
@@ -0,0 +1,74 @@
set -e
set -x

cd "${KOKORO_ROOT}/"

sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1

PYTHON_BINARY="/usr/bin/python3.10"

"${PYTHON_BINARY}" -m venv venv
source venv/bin/activate
# Check the python version
python --version
python3 --version

cd "src/github/keras"
pip install -U pip setuptools
# psutil is used by background log reader
pip install -U psutil

if [ "$KERAS_BACKEND" == "tensorflow" ]
then
   echo "TensorFlow backend detected."
   pip install -r requirements-tensorflow-tpu.txt --progress-bar off --timeout 1000
   pip uninstall -y keras keras-nightly
   echo "Check that TensorFlow uses TPU"
   python3 -c 'import tensorflow as tf;print(tf.__version__);print(tf.config.list_physical_devices("TPU"))'
   # Raise error if TPU is not detected.
   python3 -c 'import tensorflow as tf;assert len(tf.config.list_physical_devices("TPU")) > 0'

   # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
   pytest keras --ignore keras/src/applications \
               --ignore keras/src/layers/merging/merging_test.py \
               --cov=keras \
               --cov-config=pyproject.toml
fi

if [ "$KERAS_BACKEND" == "jax" ]
then
   echo "JAX backend detected."
   pip install -r requirements-jax-tpu.txt --progress-bar off --timeout 1000
   pip uninstall -y keras keras-nightly
   python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())'
   # Raise error if TPU is not detected.
   python3 -c 'import jax;assert jax.default_backend().lower() == "tpu"'

   # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
   # TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted
   # keras/backend/jax/distribution_lib_test.py is configured for CPU test for now.
   pytest keras --ignore keras/src/applications \
               --ignore keras/src/layers/merging/merging_test.py \
               --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \
               --ignore keras/src/backend/jax/distribution_lib_test.py \
               --ignore keras/src/distribution/distribution_lib_test.py \
               --cov=keras \
               --cov-config=pyproject.toml

   pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml
fi

if [ "$KERAS_BACKEND" == "torch" ]
then
   echo "PyTorch backend detected."
   # NOTE: this branch installs the CUDA requirements and checks for a GPU, not a TPU.
   pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000
   pip uninstall -y keras keras-nightly
   python3 -c 'import torch;print(torch.__version__);print(torch.cuda.is_available())'
   # Raise error if GPU is not detected.
   python3 -c 'import torch;assert torch.cuda.is_available()'

   pytest keras --ignore keras/src/applications \
               --cov=keras \
               --cov-config=pyproject.toml

fi
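Each backend branch in the script runs `pytest` with its own list of `--ignore` paths. The same selection can be sketched as a lookup table in Python, which makes the per-backend differences easier to compare. `IGNORED_PATHS` and `pytest_args` are hypothetical names used only for illustration; the paths are copied from the script.

```python
from typing import Dict, List

# Paths each backend branch of build.sh excludes from its main pytest run.
IGNORED_PATHS: Dict[str, List[str]] = {
    "tensorflow": [
        "keras/src/applications",
        "keras/src/layers/merging/merging_test.py",
    ],
    "jax": [
        "keras/src/applications",
        "keras/src/layers/merging/merging_test.py",
        "keras/src/trainers/data_adapters/py_dataset_adapter_test.py",
        "keras/src/backend/jax/distribution_lib_test.py",
        "keras/src/distribution/distribution_lib_test.py",
    ],
    "torch": ["keras/src/applications"],
}


def pytest_args(backend: str) -> List[str]:
    """Return the argument list for the main pytest run of one backend."""
    try:
        ignored = IGNORED_PATHS[backend]
    except KeyError:
        raise ValueError(f"unsupported backend: {backend!r}") from None
    args = ["pytest", "keras"]
    for path in ignored:
        args += ["--ignore", path]
    args += ["--cov=keras", "--cov-config=pyproject.toml"]
    return args


print(" ".join(pytest_args("torch")))
```

Unlike the shell script, which silently runs no tests when `KERAS_BACKEND` has an unexpected value, this sketch raises an error for an unknown backend.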
16 changes: 16 additions & 0 deletions .github/workflows/tpu/tensorflow/continuous.cfg
@@ -0,0 +1,16 @@
build_file: "keras/.github/workflows/tpu/build.sh"

action {
  define_artifacts {
    regex: "**/sponge_log.log"
    regex: "**/sponge_log.xml"
  }
}

env_vars: {
  key: "KERAS_BACKEND"
  value: "tensorflow"
}

# Set timeout to 60 mins from default 180 mins
timeout_mins: 60
16 changes: 16 additions & 0 deletions .github/workflows/tpu/tensorflow/presubmit.cfg
@@ -0,0 +1,16 @@
build_file: "keras/.github/workflows/tpu/build.sh"

action {
  define_artifacts {
    regex: "**/sponge_log.log"
    regex: "**/sponge_log.xml"
  }
}

env_vars: {
  key: "KERAS_BACKEND"
  value: "tensorflow"
}

# Set timeout to 60 mins from default 180 mins
timeout_mins: 60
28 changes: 28 additions & 0 deletions Dockerfile
@@ -0,0 +1,28 @@
FROM --platform=linux/amd64 python:3.10-slim

ENV KERAS_HOME=/github/workspace/.github/workflows/config/jax \
    KERAS_BACKEND=jax

RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    sudo \
    && rm -rf /var/lib/apt/lists/*

# Copy the entire codebase into the container
COPY . /github/workspace
WORKDIR /github/workspace

# Install pip/setuptools/psutil and the JAX TPU requirements, then uninstall any
# keras / keras-nightly packages. The backend check and test run below are left commented out.
# RUN cd ./keras/src/github/keras && \
RUN pip install --no-cache-dir -U pip setuptools && \
    pip install --no-cache-dir -U psutil && \
    pip install --no-cache-dir -r requirements-jax-tpu.txt && \
    pip uninstall -y keras keras-nightly
    # python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())' && \
    # python3 -c 'import jax;assert jax.default_backend().lower() == "tpu"' && \
    # pytest keras --ignore keras/src/applications \
    #              --ignore keras/src/layers/merging/merging_test.py \
    #              --cov=keras \
    #              --cov-config=pyproject.toml

CMD ["/bin/bash"]
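The image sets both `KERAS_HOME` and `KERAS_BACKEND`, which Keras consults when choosing a backend. The sketch below approximates that precedence in a simplified form: the `KERAS_BACKEND` environment variable first, then the `backend` key in `keras.json` under `KERAS_HOME`, then the `tensorflow` default. `resolve_backend` is a hypothetical helper, not Keras's own code, and it ignores the fallback to `~/.keras` when `KERAS_HOME` is unset.

```python
import json
import os
from typing import Mapping


def resolve_backend(environ: Mapping[str, str]) -> str:
    """Approximate which backend name Keras would pick for the given environment."""
    backend = "tensorflow"  # default when nothing else is configured
    keras_home = environ.get("KERAS_HOME")
    if keras_home:
        config_path = os.path.join(keras_home, "keras.json")
        try:
            with open(config_path) as f:
                config = json.load(f)
            if isinstance(config, dict):
                backend = config.get("backend", backend)
        except (OSError, ValueError):
            # Missing or malformed config: keep the default.
            pass
    # A non-empty environment variable overrides the config file.
    return environ.get("KERAS_BACKEND") or backend


# The two variables set by the Dockerfile's ENV instruction.
print(resolve_backend({
    "KERAS_HOME": "/github/workspace/.github/workflows/config/jax",
    "KERAS_BACKEND": "jax",
}))
```

With both variables set as in the Dockerfile, the environment variable decides the result, so the sketch prints `jax` whether or not the config directory exists.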