Skip to content

Commit 55888e6

Browse files
committed
work
1 parent 8517142 commit 55888e6

File tree

13 files changed

+143
-160
lines changed

13 files changed

+143
-160
lines changed

.github/FUNDING.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
github: [lockwo]

.github/workflows/style-check.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ jobs:
3232
python-version: ${{ matrix.python-version }}
3333
- name: Run styles & lint check
3434
run: |
35-
pip install -e .
36-
pip install ruff pyright
35+
pip install -e .[testing]
3736
ruff format --check isax/ tests/ examples/
3837
ruff check isax/ tests/ examples/
3938
pyright isax/

README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,40 @@ pip install -e .
1212

1313
Requires Python 3.10+.
1414

15+
## Quick Example
16+
17+
```python
18+
import jax
19+
import jax.numpy as jnp
20+
from isax import BlockGraph, Edge, Node, IsingModel, IsingSampler, SamplingArgs, sample_chain
21+
22+
L = 4
23+
nodes = [Node() for _ in range(L * L)]
24+
25+
edges = []
26+
for x in range(L):
27+
for y in range(L):
28+
i = x * L + y
29+
edges.append(Edge(nodes[i], nodes[(x * L + (y + 1) % L)]))
30+
edges.append(Edge(nodes[i], nodes[((x + 1) % L) * L + y]))
31+
32+
even = [nodes[x * L + y] for x in range(L) for y in range(L) if (x + y) % 2 == 0]
33+
odd = [nodes[x * L + y] for x in range(L) for y in range(L) if (x + y) % 2 == 1]
34+
35+
graph = BlockGraph([even, odd], edges)
36+
params = graph.get_sampling_params()
37+
38+
model = IsingModel(weights=jnp.ones(len(edges)), biases=jnp.zeros(L * L))
39+
sampler = IsingSampler()
40+
sampling_args = SamplingArgs(gibbs_steps=100, blocks_to_sample=[0, 1], data=params)
41+
42+
key = jax.random.key(0)
43+
init_state = [jax.random.choice(key, jnp.array([-1, 1]), (len(even),)),
44+
jax.random.choice(key, jnp.array([-1, 1]), (len(odd),))]
45+
46+
samples = sample_chain(init_state, [sampler, sampler], model, sampling_args, key)
47+
```
48+
1549
## Documentation
1650

1751
Available at https://lockwo.github.io/isax
@@ -22,6 +56,7 @@ Available at https://lockwo.github.io/isax
2256
- [x] cleaner interface
2357
- [ ] improve example documentation/math background
2458
- [ ] add tests
59+
- [ ] runtime sampling params
2560
- [ ] generic block typing
2661
- [ ] generalize pytree typing for states
2762
- [ ] support non-gibbs samplers (wolff, mh, etc.)

docs/api/block.md

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,22 @@
1-
# Block Graph Module
1+
# Block Graph Tools
22

33
::: isax.block.Node
44
options:
5-
show_source: false
6-
show_root_heading: true
7-
members:
8-
- __init__
9-
- __lt__
5+
members: false
106

117
::: isax.block.Edge
128
options:
13-
show_source: false
14-
show_root_heading: true
159
members:
1610
- __init__
1711

1812
::: isax.block.EqxGraph
1913
options:
20-
show_source: false
21-
show_root_heading: true
22-
members:
23-
- node_to_global
24-
- node_to_local
25-
- block_to_global
14+
members: false
15+
2616

2717
::: isax.block.BlockGraph
2818
options:
29-
show_source: false
30-
show_root_heading: true
3119
members:
3220
- __init__
33-
- get_edge_info
21+
- get_edge_structure
3422
- get_sampling_params
35-
- node_states_to_block_states
36-
- block_states_to_node_states

docs/api/metrics.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
# Metrics Module
1+
# Metrics
22

33
Functions for computing physical observables and metrics from sampled states.
44

5-
::: isax.metrics.magnetization_per_site
6-
options:
7-
show_source: false
5+
::: isax.metrics.magnetization_per_site

docs/api/sample.md

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,27 @@
1-
# Sampling Module
2-
3-
This module contains the core sampling functionality including models, samplers, and utilities.
1+
# Sampling
42

53
## Models
64

75
::: isax.sample.IsingModel
86
options:
9-
show_source: false
10-
show_root_heading: true
117
members:
12-
- __init__
138
- energy
149
- to_sample_params
1510

1611
## Samplers
1712

1813
::: isax.sample.AbstractSampler
1914
options:
20-
show_source: false
21-
show_root_heading: true
22-
members:
23-
- sample
24-
- initialize_state
15+
members: false
2516

2617
::: isax.sample.IsingSampler
2718
options:
28-
show_source: false
29-
show_root_heading: true
3019
members:
3120
- sample
3221
- initialize_state
3322

3423
::: isax.sample.AnnealedIsingSampler
3524
options:
36-
show_source: false
37-
show_root_heading: true
3825
members:
3926
- __init__
4027
- sample
@@ -44,19 +31,8 @@ This module contains the core sampling functionality including models, samplers,
4431

4532
::: isax.sample.SamplingArgs
4633
options:
47-
show_source: false
48-
show_root_heading: true
4934
members:
5035
- __init__
5136

5237
::: isax.sample.sample_chain
53-
options:
54-
show_source: false
55-
56-
::: isax.sample.sample_blocks
57-
options:
58-
show_source: false
5938

60-
::: isax.sample.concat_state
61-
options:
62-
show_source: false

docs/index.md

Lines changed: 2 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,16 @@
1-
# ISAX
1+
# Getting Started
22

3-
Block-parallel sampling for Ising models on arbitrary graphs using JAX.
4-
5-
ISAX is a [JAX](https://github.com/google/jax)-based library for sampling from Ising models using block-parallel Gibbs sampling. It supports hypergraphs, custom sampling schedules, and JAX transformations.
6-
7-
## Quick Example
8-
9-
```python
10-
import jax
11-
import jax.numpy as jnp
12-
from isax import IsingModel, IsingSampler, BlockGraph, sample_chain, SamplingArgs
13-
14-
# Create a 2D lattice Ising model
15-
n_nodes = 100
16-
edges = [(i, j) for i in range(n_nodes) for j in range(i+1, n_nodes)
17-
if abs(i-j) == 1 or abs(i-j) == 10] # 2D grid connectivity
18-
19-
# Initialize model with random weights
20-
key = jax.random.PRNGKey(0)
21-
weight_key, sample_key = jax.random.split(key)
22-
weights = jax.random.normal(weight_key, (len(edges),))
23-
biases = jnp.zeros(n_nodes)
24-
25-
model = IsingModel(weights=weights, biases=biases)
26-
27-
# Create block structure for parallel sampling
28-
graph = BlockGraph(n_nodes, edges, n_blocks=4)
29-
30-
# Sample using Gibbs sampler
31-
sampler = IsingSampler()
32-
initial_state = jax.random.choice(sample_key, jnp.array([-1, 1]), (n_nodes,))
33-
34-
# Run sampling
35-
samples = sample_chain(
36-
initial_state,
37-
[sampler] * graph.n_blocks,
38-
model,
39-
SamplingArgs(gibbs_steps=1000, blocks_to_sample=list(range(4)), data=graph.get_sampling_params()),
40-
sample_key
41-
)
42-
```
3+
isax is a [JAX](https://github.com/google/jax)-based library for sampling from Ising models using blocked Gibbs sampling. It supports hypergraphs, flexible sampling/modeling, and all the usual JAX transformations. isax is heavily inspired by [thrml](https://docs.thrml.ai/en/latest/).
434

445
## Installation
456

46-
```bash
47-
pip install isax
48-
```
49-
50-
Or install from source:
51-
527
```bash
538
git clone https://github.com/lockwo/isax
549
cd isax
5510
pip install -e .
5611
```
5712

58-
## Features
59-
60-
Traditional MCMC sampling for Ising models can be slow for large graphs. ISAX accelerates this through:
61-
62-
1. **Block decomposition**: Divide the graph into blocks that can be updated in parallel
63-
2. **JAX compilation**: JIT-compile the sampling loops for maximum performance
64-
3. **Vectorization**: Process multiple chains simultaneously with `vmap`
65-
4. **Hardware acceleration**: Automatic GPU/TPU support through JAX
66-
67-
## Navigation
68-
69-
- [API Reference](api/block.md) - Complete API documentation
70-
- [Examples](examples/01_ising_model_physics.ipynb) - Jupyter notebooks with practical examples
7113

7214
## Citation
7315

7416
If you use ISAX in your research, please cite:
75-
76-
```bibtex
77-
@software{isax2025,
78-
title = {ISAX: Block-parallel sampling for Ising models},
79-
year = {2025},
80-
url = {https://github.com/lockwo/isax}
81-
}
82-
```

examples/03_ising_thermodynamics.ipynb

Lines changed: 16 additions & 36 deletions
Large diffs are not rendered by default.

examples/04_paoa_optimization.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"source": [
77
"# Probabilistic Approximate Optimization Algorithm\n",
88
"\n",
9-
"In this example, we recreate Figure 3(c) of https://arxiv.org/abs/2507.07420."
9+
"In this example, we recreate Figure 3(c) of [https://arxiv.org/abs/2507.07420](https://arxiv.org/abs/2507.07420)."
1010
]
1111
},
1212
{

mkdocs.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ theme:
88
- content.code.copy
99
palette:
1010
# Light mode / dark mode
11-
# We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as
12-
# (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle.
1311
- scheme: default
1412
primary: white
1513
accent: amber
@@ -26,9 +24,9 @@ theme:
2624
repo: fontawesome/brands/github # GitHub logo in top right
2725
logo: "material/graph-outline"
2826

29-
site_name: ISAX
27+
site_name: isax
3028
site_description: Block-parallel sampling for Ising models on arbitrary graphs using JAX.
31-
site_author: ISAX Contributors
29+
site_author: lockwo
3230
site_url: https://github.com/lockwo/isax
3331

3432
repo_url: https://github.com/lockwo/isax

0 commit comments

Comments
 (0)