Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ These are some tutorials to get you started with Haliax. They are available as C
* [Mixed Precision with `jmp`](https://colab.research.google.com/drive/1_4cikwt-UhSH7yRzNRK8ze9msM9r2mEl?usp=sharing) (This one is really a tutorial for [jmp](https://github.com/deepmind/jmp) but it's how to use it with Haliax...)

<!--haliax-tutorials-end-->

### Examples

These example notebooks illustrate small self-contained uses of Haliax.

<!--haliax-examples-start-->
* [MNIST classification with synthetic data](https://github.com/stanford-crfm/haliax/blob/main/docs/mnist_example.ipynb)
<!--haliax-examples-end-->
### API Reference

Haliax's API documentation is available at [haliax.readthedocs.io](https://haliax.readthedocs.io/en/latest/).
Expand Down
5 changes: 5 additions & 0 deletions docs/examples.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
These example notebooks can be run locally to demonstrate basic Haliax usage.

{%
include-markdown "../README.md" start="<!--haliax-examples-start-->" end="<!--haliax-examples-end-->"
%}
159 changes: 159 additions & 0 deletions docs/mnist_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "8cfdbe4b",
"metadata": {},
"source": [
"# MNIST Example with Haliax\n",
"This notebook trains a simple neural network on MNIST using Haliax."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "84409554",
"metadata": {
"execution": {
"iopub.execute_input": "2025-07-04T20:25:37.026198Z",
"iopub.status.busy": "2025-07-04T20:25:37.025957Z",
"iopub.status.idle": "2025-07-04T20:25:38.021610Z",
"shell.execute_reply": "2025-07-04T20:25:38.019338Z"
}
},
"outputs": [],
"source": [
"import jax, jax.numpy as jnp\n",
"import haliax as hax\n",
"import haliax.nn as hnn\n",
"import equinox as eqx\n",
"from jax import random\n",
"import optax\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cac06d71",
"metadata": {
"execution": {
"iopub.execute_input": "2025-07-04T20:25:38.044803Z",
"iopub.status.busy": "2025-07-04T20:25:38.042415Z",
"iopub.status.idle": "2025-07-04T20:25:38.059738Z",
"shell.execute_reply": "2025-07-04T20:25:38.053521Z"
}
},
"outputs": [],
"source": [
"# Generate synthetic data (num_batches of random images)\n",
"num_batches = 10\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "08324579",
"metadata": {
"execution": {
"iopub.execute_input": "2025-07-04T20:25:38.086710Z",
"iopub.status.busy": "2025-07-04T20:25:38.086082Z",
"iopub.status.idle": "2025-07-04T20:25:38.100063Z",
"shell.execute_reply": "2025-07-04T20:25:38.096324Z"
}
},
"outputs": [],
"source": [
"Batch = hax.Axis('batch', 128)\n",
"Image = hax.Axis('image', 28*28)\n",
"Hidden = hax.Axis('hidden', 256)\n",
"Classes = hax.Axis('classes', 10)\n",
"\n",
"class Net(eqx.Module):\n",
" mlp: hnn.MLP\n",
" @staticmethod\n",
" def init(key):\n",
" mlp = hnn.MLP.init(Image, Classes, width=Hidden, depth=2, key=key)\n",
" return Net(mlp)\n",
" def __call__(self, x):\n",
" return self.mlp(x)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "01b72725",
"metadata": {
"execution": {
"iopub.execute_input": "2025-07-04T20:25:38.106923Z",
"iopub.status.busy": "2025-07-04T20:25:38.106657Z",
"iopub.status.idle": "2025-07-04T20:25:38.158969Z",
"shell.execute_reply": "2025-07-04T20:25:38.153673Z"
}
},
"outputs": [],
"source": [
"def loss_fn(model, images, labels):\n",
" imgs = hax.NamedArray(images.reshape(-1, 28*28), (Batch, Image))\n",
" logits = model(imgs)\n",
" loss = hnn.cross_entropy_loss(logits, Classes, labels)\n",
" return loss.mean().scalar()\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d2c07739",
"metadata": {
"execution": {
"iopub.execute_input": "2025-07-04T20:25:38.190928Z",
"iopub.status.busy": "2025-07-04T20:25:38.189543Z",
"iopub.status.idle": "2025-07-04T20:25:42.141138Z",
"shell.execute_reply": "2025-07-04T20:25:42.137318Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch done\n"
]
}
],
"source": [
"key = random.PRNGKey(0)\n",
"model = Net.init(key)\n",
"opt = optax.adam(1e-3)\n",
"opt_state = opt.init(model)\n",
"\n",
"for epoch in range(1):\n",
" for _ in range(num_batches):\n",
" key, subkey1, subkey2 = random.split(key, 3)\n",
" images = random.normal(subkey1, (Batch.size, 28*28))\n",
" label_ids = random.randint(subkey2, (Batch.size,), 0, Classes.size)\n",
" labels = hnn.one_hot(hax.NamedArray(label_ids, (Batch,)), Classes)\n",
" grads = jax.grad(loss_fn)(model, images, labels)\n",
" updates, opt_state = opt.update(grads, opt_state, params=model)\n",
" model = eqx.apply_updates(model, updates)\n",
" print(\"epoch done\")\n"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ nav:
- "Distributed Training and FSDP": https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz
- "Tensor Parallelism": https://colab.research.google.com/drive/18_BrtDpe1lu89M4T6fKzda8DdSLtFJhi
- "Mixed Precision with `jmp`": https://colab.research.google.com/drive/1_4cikwt-UhSH7yRzNRK8ze9msM9r2mEl?usp=sharing
- Examples: 'examples.md'
- Cheatsheet: 'cheatsheet.md'
- Named Arrays:
- Broadcasting: 'broadcasting.md'
Expand Down
Loading