Commit 1359446

Merge branch 'main' into dependabot/cargo/rust/cargo-553fd7204f

2 parents d555a46 + 9df0972

52 files changed: +8900 −134 lines

.github/workflows/deploy.yml

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+name: Deploy to GitHub Pages
+
+on:
+  # Trigger the workflow every time you push to the `main` branch
+  # Using a different branch name? Replace `main` with your branch’s name
+  push:
+    branches: [ main ]
+    paths: [ website ]
+  merge_group:
+  # Allows you to run this workflow manually from the Actions tab on GitHub.
+  workflow_dispatch:
+
+# Allow this job to clone the repo and create a page deployment
+permissions:
+  contents: read
+  pages: write
+  id-token: write
+
+# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
+concurrency:
+  group: "pages"
+  cancel-in-progress: false
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout your repository using git
+        uses: actions/checkout@v4
+      - name: Install, build, and upload your site
+        uses: withastro/action@v3
+        with:
+          path: website # The root location of your Astro project inside the repository. (optional)
+          # node-version: 20 # The specific version of Node that should be used to build your site. Defaults to 20. (optional)
+          # package-manager: pnpm@latest # The Node package manager that should be used to install dependencies and build your site. Automatically detected based on your lockfile. (optional)
+
+  deploy:
+    needs: build
+    runs-on: ubuntu-latest
+    environment:
+      name: github-pages
+      url: ${{ steps.deployment.outputs.page_url }}
+    steps:
+      - name: Deploy to GitHub Pages
+        id: deployment
+        uses: actions/deploy-pages@v4

.github/workflows/maturin_ci.yml

Lines changed: 6 additions & 1 deletion
@@ -3,14 +3,19 @@
 #
 # maturin generate-ci github
 #
+# Prefer to run matrix of unit-tests in the pytest workflow until the following
+# issue is solved: https://github.com/PyO3/maturin/issues/1971
 name: Maturin CI
 
 on:
   push:
     paths:
      - '**/*.py'
      - '**/*.rs'
-     - 'pytest.ini'
+     - 'pyproject.toml'
+     - 'requirements.txt'
+     - 'rust/Cargo.toml'
+     - 'rust/Cargo.lock'
   merge_group:
   pull_request:
     types: [opened, synchronize, reopened]

.github/workflows/mypy.yml

Lines changed: 1 addition & 1 deletion
@@ -50,4 +50,4 @@ jobs:
         pip install -r test_requirements.txt
         mkdir -p .mypy_cache
         mypy --version
-        mypy --no-color-output --install-types --non-interactive src
+        mypy --no-color-output --install-types --non-interactive src docs

.github/workflows/piptest.yml

Lines changed: 11 additions & 2 deletions
@@ -12,7 +12,16 @@ on:
 
 jobs:
   piptesting:
-    runs-on: ubuntu-22.04
+    runs-on: ${{ matrix.platform.runner }}
+    strategy:
+      matrix:
+        # ubuntu-24.04-arm is not stable enough
+        platform:
+          - runner: ubuntu-latest # x64
+          - runner: windows-latest # x64
+          - runner: macos-13 # Intel
+          - runner: macos-14 # arm64
+          - runner: macos-latest # arm64
     steps:
       - uses: actions/checkout@v4
      - name: Set up Python 3.10
@@ -26,4 +35,4 @@ jobs:
      - name: Run tutorial using sedpack pip package
        run: |
          python docs/tutorials/quick_start/mnist_save.py -d mnist_dataset
-          python docs/tutorials/quick_start/mnist_read.py -d mnist_dataset
+          python docs/tutorials/quick_start/mnist_read_keras.py -d mnist_dataset
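Each entry of the `platform` matrix above fans out into its own independent job. As a rough illustration of that fan-out (a standalone Python sketch; the names `matrix` and `jobs` are illustrative, not part of any GitHub API):

    # Each matrix entry becomes one job, mirroring how GitHub Actions
    # expands strategy.matrix.platform into five runners.
    matrix = {
        "platform": [
            {"runner": "ubuntu-latest"},   # x64
            {"runner": "windows-latest"},  # x64
            {"runner": "macos-13"},        # Intel
            {"runner": "macos-14"},        # arm64
            {"runner": "macos-latest"},    # arm64
        ],
    }

    jobs = [entry["runner"] for entry in matrix["platform"]]
    print(jobs)  # five jobs, one per runner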

.github/workflows/pytest.yml

Lines changed: 30 additions & 16 deletions
@@ -13,7 +13,16 @@ on:
 
 jobs:
   unittesting:
-    runs-on: ubuntu-22.04
+    runs-on: ${{ matrix.platform.runner }}
+    strategy:
+      matrix:
+        # ubuntu-20.04-arm was not stable enough when testing
+        platform:
+          - runner: ubuntu-latest # x64
+          - runner: windows-latest # x64
+          - runner: macos-13 # Intel
+          - runner: macos-14 # arm64
+          - runner: macos-latest # arm64
     if: github.event_name != 'schedule'
     steps:
       - uses: actions/checkout@v4
@@ -33,25 +42,30 @@ jobs:
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          # The cache key depends on requirements.txt
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
+          key: ${{ matrix.platform.runner }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements*.txt') }}-${{ hashFiles('test_requirements*.txt') }}
       # Build a virtualenv, but only if it doesn't already exist
      - name: Populate pip cache
-        run: python -m pip install --require-hashes --no-deps -r requirements.txt
+        # requirements.txt is not reliable since across different platforms and
+        # their versions the pip package versions might vary. We regenerate it
+        # again from pyproject.toml every time when pyproject.toml or
+        # requirements.txt changes. The pinned versions in requirements.txt are
+        # tested by coverage since that is running on ubuntu which is also used
+        # to produce the main requirements.txt file.
+        run: |
+          pip install pip-tools
+          pip-compile --generate-hashes pyproject.toml > requirements.txt
+          pip install -r requirements.txt
+          pip install -r test_requirements.txt
+        if: steps.cache.outputs.cache-hit != 'true'
      - name: Save cache
        id: cache-save
        uses: actions/cache/save@v4
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: ${{ steps.cache.outputs.cache-primary-key }}
        if: steps.cache.outputs.cache-hit != 'true'
-      - name: Installing test requirements and sedpack
-        # Start by "installing" sedpack to be sure all dependencies are listed
-        run: |
-          pip install -r test_requirements.txt
-          pip install --editable .
-          echo "PYTHONPATH=./src:$PYTHONPATH" >> $GITHUB_ENV
+      - name: Install sedpack locally
+        run: pip install --editable .
      - name: Running unit tests
        run: |
          python -m pytest
@@ -76,12 +90,13 @@ jobs:
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          # The cache key depends on requirements.txt
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
+          key: ${{ runner.os }}-pip-${{ hashFiles('requirements*.txt') }}-${{ hashFiles('test_requirements*.txt') }}
       # Build a virtualenv, but only if it doesn't already exist
      - name: Populate pip cache
-        run: python -m pip install --require-hashes --no-deps -r requirements.txt
+        run: |
+          python -m pip install --require-hashes --no-deps -r requirements.txt
+          pip install -r test_requirements.txt
+        if: steps.cache.outputs.cache-hit != 'true'
      - name: Save cache
        id: cache-save
        uses: actions/cache/save@v4
@@ -92,7 +107,6 @@ jobs:
      - name: Installing test requirements and sedpack
        # Start by "installing" sedpack to be sure all dependencies are listed
        run: |
-          pip install -r test_requirements.txt
          pip install --editable .
          echo "PYTHONPATH=./src:$PYTHONPATH" >> $GITHUB_ENV
      - name: Install workflow dependencies
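The new cache keys chain several `hashFiles(...)` results so that editing `pyproject.toml` or any requirements file changes the key and invalidates the cache. A minimal Python sketch of that idea (the helper `cache_key` is hypothetical, a stand-in for the workflow's key expression; GitHub's real `hashFiles` produces a single SHA-256 over the matched files):

    import hashlib
    from pathlib import Path

    def cache_key(runner: str, *patterns: str) -> str:
        """Hypothetical stand-in for the workflow's cache-key expression."""
        digest = hashlib.sha256()
        for pattern in patterns:
            # Hash the contents of every matching file, in a stable order.
            for path in sorted(Path(".").glob(pattern)):
                digest.update(path.read_bytes())
        return f"{runner}-pip-{digest.hexdigest()}"

    # Any change to these files yields a different key, forcing a rebuild.
    print(cache_key("ubuntu-latest", "pyproject.toml", "requirements*.txt",
                    "test_requirements*.txt"))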

README.md

Lines changed: 8 additions & 4 deletions
@@ -2,6 +2,8 @@
 
 [![Coverage Status](https://coveralls.io/repos/github/google/sedpack/badge.svg?branch=main)](https://coveralls.io/github/google/sedpack?branch=main)
 
+[Documentation](https://google.github.io/sedpack/)
+
 Mainly refactored from the [SCAAML](https://github.com/google/scaaml) project.
 
 ## Available components
@@ -53,10 +55,12 @@ Update: `pip-compile pyproject.toml --generate-hashes --upgrade` and commit requ
 
 ### Tutorial
 
-Tutorials available in the docs/tutorials/ directory. For a "hello world" see
-[docs/tutorials/quick_start/mnist_save.py](https://github.com/google/sedpack/blob/main/docs/tutorials/quick_start/mnist_save.py)
-and
-[docs/tutorials/quick_start/mnist_save.py](https://github.com/google/sedpack/blob/main/docs/tutorials/quick_start/mnist_read.py).
+A tutorial and documentation are available at
+[https://google.github.io/sedpack/](https://google.github.io/sedpack/).
+
+Code for the tutorials is available in the `docs/tutorials` directory. For a
+"hello world" see
+[https://google.github.io/sedpack/tutorials/mnist/](https://google.github.io/sedpack/tutorials/mnist/).
 
 ## Disclaimer

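For orientation, here is a minimal sketch of the read side of that "hello world" (assuming a dataset already written by `mnist_save.py`; the dataset path is a placeholder, and only the `Dataset.as_tfdataset` call that appears in the tutorial script below is used):

    from sedpack.io import Dataset

    # Load a dataset previously written by mnist_save.py (placeholder path).
    dataset = Dataset("~/Datasets/my_new_dataset/")
    train_data = dataset.as_tfdataset("train", batch_size=32, shuffle=1_000)

    # Each batch is a dict with the tutorial's "input" and "digit" fields.
    for batch in train_data:
        print(batch["input"].shape, batch["digit"].shape)
        break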
docs/tutorials/quick_start/mnist_read_jax.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Read MNIST data and feed it to a neural network. For a tutorial with
+explanations see: https://google.github.io/sedpack/tutorials/mnist
+
+Inspired by https://flax.readthedocs.io/en/latest/mnist_tutorial.html
+
+Example use:
+    python mnist_save.py -d "~/Datasets/my_new_dataset/"
+    python mnist_read_jax.py -d "~/Datasets/my_new_dataset/"
+"""
+import argparse
+from functools import partial
+from typing import Any
+
+from jax import Array
+from jax.typing import ArrayLike
+from flax import nnx
+import jax.numpy as jnp
+import optax
+from tqdm import tqdm
+
+from sedpack.io import Dataset
+
+
+def process_batch(d: Any) -> dict[str, Array]:
+    """Turn the NumPy arrays into JAX arrays and reshape the input to have a
+    channel.
+    """
+    batch_size: int = d["input"].shape[0]
+    return {
+        "input": jnp.array(d["input"]).reshape(batch_size, 28, 28, 1),
+        "digit": jnp.array(d["digit"], jnp.int32),
+    }
+
+
+class CNN(nnx.Module):  # type: ignore[misc]
+    """FLAX CNN model.
+    """
+
+    def __init__(self, *, rngs: nnx.Rngs) -> None:
+        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
+        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
+        self.avg_pool = partial(nnx.avg_pool,
+                                window_shape=(2, 2),
+                                strides=(2, 2))
+        self.linear1 = nnx.Linear(3_136, 256, rngs=rngs)
+        self.linear2 = nnx.Linear(256, 10, rngs=rngs)
+
+    def __call__(self, x: Array) -> Array:
+        x = self.avg_pool(nnx.relu(self.conv1(x)))
+        x = self.avg_pool(nnx.relu(self.conv2(x)))
+        x = x.reshape(x.shape[0], -1)  # flatten
+        x = nnx.relu(self.linear1(x))
+        x = self.linear2(x)
+        return x
+
+
+def loss_fn(model: CNN, batch: dict[str, Array]) -> tuple[Array, Array]:
+    logits = model(batch["input"])
+    loss = optax.softmax_cross_entropy_with_integer_labels(
+        logits=logits, labels=batch["digit"]).mean()
+    return loss, logits
+
+
+@nnx.jit  # type: ignore[misc]
+def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric,
+               batch: dict[str, ArrayLike]) -> None:
+    """Train for a single step.
+    """
+    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
+    (loss, logits), grads = grad_fn(model, batch)
+    metrics.update(loss=loss, logits=logits, labels=batch["digit"])
+    optimizer.update(grads)
+
+
+@nnx.jit  # type: ignore[misc]
+def eval_step(
+    model: CNN,
+    metrics: nnx.MultiMetric,
+    batch: dict[str, Array],
+) -> None:
+    loss, logits = loss_fn(model, batch)
+    metrics.update(loss=loss, logits=logits, labels=batch["digit"])
+
+
+def main() -> None:
+    """Train a neural network on the MNIST dataset saved in the sedpack
+    format.
+    """
+    parser = argparse.ArgumentParser(
+        description=
+        "Read MNIST in dataset-lib format and train a small neural network.")
+    parser.add_argument("--dataset_directory",
+                        "-d",
+                        help="Where to load the dataset",
+                        required=True)
+    parser.add_argument("--ascii_evaluations",
+                        "-e",
+                        help="How many images to print and evaluate",
+                        type=int,
+                        default=10)
+    args = parser.parse_args()
+
+    model = CNN(rngs=nnx.Rngs(0))
+    nnx.display(model)
+
+    learning_rate: float = 0.005
+    momentum: float = 0.9
+    optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
+    metrics = nnx.MultiMetric(
+        accuracy=nnx.metrics.Accuracy(),
+        loss=nnx.metrics.Average("loss"),
+    )
+    nnx.display(optimizer)
+
+    metrics_history: dict[str, list[Array]] = {
+        "train_loss": [],
+        "train_accuracy": [],
+        "test_loss": [],
+        "test_accuracy": [],
+    }
+
+    dataset = Dataset(args.dataset_directory)  # Load the dataset
+    batch_size = 32
+    train_data = dataset.as_tfdataset(
+        "train",
+        batch_size=batch_size,
+        shuffle=1_000,
+    )
+    validation_data = dataset.as_tfdataset(
+        "test",  # validation split
+        batch_size=batch_size,
+        shuffle=1_000,
+        repeat=False,
+    )
+    train_steps: int = 1_200
+    eval_every: int = 200
+
+    for step, batch in enumerate(tqdm(train_data)):
+        if step > train_steps:
+            break
+
+        # Run the optimization for one step and make a stateful update to the
+        # following:
+        # - The train state's model parameters
+        # - The optimizer state
+        # - The training loss and accuracy batch metrics
+        batch = process_batch(batch)
+        train_step(model, optimizer, metrics, batch)
+
+        if step > 0 and (step % eval_every == 0 or step
+                         == train_steps - 1):  # One training epoch has passed.
+            # Log the training metrics.
+            # Compute the metrics.
+            for metric, value in metrics.compute().items():
+                # Record the metrics.
+                metrics_history[f"train_{metric}"].append(value)
+                print(f"{metric} = {value}", end=" ")
+            metrics.reset()  # Reset the metrics for the test set.
+            print()
+
+            # Compute the metrics on the test set after each training epoch.
+            for test_batch in validation_data.as_numpy_iterator():
+                test_batch = process_batch(test_batch)
+                eval_step(model, metrics, test_batch)
+
+            # Log the test metrics.
+            for metric, value in metrics.compute().items():
+                metrics_history[f"test_{metric}"].append(value)
+                print(f"test {metric} = {value}", end=" ")
+            metrics.reset()  # Reset the metrics for the next training epoch.
+            print()
+
+
+if __name__ == "__main__":
+    main()
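As a quick sanity check of the shape handling in `process_batch`, this standalone sketch builds a fake NumPy batch shaped like what the tutorial's pipeline yields (the batch contents are made up; only the reshape from the script above is exercised):

    import numpy as np
    import jax.numpy as jnp

    # Fake batch: 32 MNIST images of 28x28 pixels plus their labels.
    fake = {
        "input": np.zeros((32, 28, 28), dtype=np.float32),
        "digit": np.arange(32),
    }
    out = {
        "input": jnp.array(fake["input"]).reshape(32, 28, 28, 1),
        "digit": jnp.array(fake["digit"], jnp.int32),
    }
    print(out["input"].shape)  # (32, 28, 28, 1) — channel axis added for nnx.Conv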
