Skip to content

Commit c31d00d

Browse files
committed
work
1 parent 3774cd6 commit c31d00d

13 files changed

+1263
-465
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,5 @@ cython_debug/
205205
marimo/_static/
206206
marimo/_lsp/
207207
__marimo__/
208+
209+
.DS_Store

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,10 @@
11
# isax
2+
23
Ising (like) models in JAX
4+
5+
todo:
6+
- annealing
7+
- cleaner interface
8+
- generic block typing
9+
- generalize pytree typing for states
10+
- support non-gibbs samplers (wolff, mh, etc.)

examples/01_ising_model_physics.ipynb

Lines changed: 250 additions & 0 deletions
Large diffs are not rendered by default.

examples/02_fps.ipynb

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"source": [
88
"# Flips Per Second\n",
99
"\n",
10-
"In this example we will be looking as Flips Per Second (FPS) as a metric for evaluating the speed of a Gibbs sampling program. This highlights the power of GPUs for workloads which have many parallel Gibbs chains to be sampled."
10+
"In this example we will be looking at Flips Per Second (FPS) as a metric for evaluating the speed of a Gibbs sampling program."
1111
]
1212
},
1313
{
@@ -19,12 +19,20 @@
1919
"source": [
2020
"import time\n",
2121
"\n",
22+
"import dwave_networkx\n",
2223
"import jax\n",
2324
"import jax.numpy as jnp\n",
24-
"import dwave_networkx\n",
25+
"import matplotlib.pyplot as plt\n",
2526
"import networkx as nx\n",
26-
"from isax.block import BlockGraph, Edge, AbstractNode\n",
27-
"from isax.sample import IsingModel, IsingSampler, SamplingArgs, sample_chain"
27+
"from isax import (\n",
28+
" BlockGraph,\n",
29+
" Edge,\n",
30+
" IsingModel,\n",
31+
" IsingSampler,\n",
32+
" Node,\n",
33+
" sample_chain,\n",
34+
" SamplingArgs,\n",
35+
")"
2836
]
2937
},
3038
{
@@ -35,39 +43,28 @@
3543
"outputs": [],
3644
"source": [
3745
"def create_dwave_pegasus_graph(pegasus_size, key):\n",
38-
" \"\"\"Create DWave Pegasus graph like thrml does\"\"\"\n",
39-
" # Create the DWave Pegasus graph\n",
4046
" graph = dwave_networkx.pegasus_graph(pegasus_size)\n",
41-
" \n",
42-
" # Convert coordinates to AbstractNode objects (isax uses AbstractNode, not SpinNode)\n",
43-
" coord_to_node = {coord: AbstractNode() for coord in graph.nodes}\n",
47+
" coord_to_node = {coord: Node() for coord in graph.nodes}\n",
4448
" nx.relabel_nodes(graph, coord_to_node, copy=False)\n",
45-
" \n",
46-
" # Convert to our format\n",
4749
" nodes = list(graph.nodes)\n",
4850
" edges = [Edge(u, v) for u, v in graph.edges()]\n",
49-
" \n",
50-
" # Use graph coloring to create blocks (similar to thrml approach)\n",
51+
" # todo: dwave ships colorings already?\n",
5152
" coloring = nx.coloring.greedy_color(graph, strategy=\"DSATUR\")\n",
5253
" n_colors = max(coloring.values()) + 1\n",
5354
" blocks = [[] for _ in range(n_colors)]\n",
54-
" \n",
55-
" # Form color groups\n",
5655
" for node in graph.nodes:\n",
5756
" blocks[coloring[node]].append(node)\n",
58-
" \n",
59-
" # Create BlockGraph with corrected argument order\n",
6057
" block_graph = BlockGraph(blocks, edges)\n",
61-
" \n",
62-
" # Generate random parameters\n",
6358
" key1, key2 = jax.random.split(key)\n",
6459
" biases = jax.random.uniform(key1, (len(nodes),), minval=-0.1, maxval=0.1)\n",
6560
" weights = jax.random.uniform(key2, (len(edges),), minval=-0.1, maxval=0.1)\n",
66-
" \n",
61+
"\n",
6762
" model = IsingModel(weights=weights, biases=biases)\n",
68-
" \n",
69-
" print(f\"Created Pegasus graph with {len(nodes)} nodes, {len(edges)} edges, {n_colors} color blocks\")\n",
70-
" \n",
63+
"\n",
64+
" print(\n",
65+
" f\"Created Pegasus graph with {len(nodes)} nodes, {len(edges)} edges, {n_colors} color blocks\"\n",
66+
" )\n",
67+
"\n",
7168
" return model, block_graph, nodes, blocks"
7269
]
7370
},
@@ -82,42 +79,44 @@
8279
" model, block_graph, nodes, blocks, chain_len, batch_size, n_reps, device\n",
8380
"):\n",
8481
" key = jax.random.key(42)\n",
85-
" \n",
82+
"\n",
8683
" (adjs, masks, edge_infos), eqx_graph = block_graph.get_sampling_params()\n",
87-
" \n",
88-
" # Create samplers for each block\n",
84+
"\n",
8985
" samplers = [IsingSampler() for _ in range(len(blocks))]\n",
90-
" \n",
86+
"\n",
9187
" sampling_args = SamplingArgs(\n",
9288
" gibbs_steps=chain_len,\n",
9389
" blocks_to_sample=list(range(len(blocks))),\n",
9490
" adjs=adjs,\n",
9591
" masks=masks,\n",
9692
" edge_info=edge_infos,\n",
97-
" eqx_graph=eqx_graph\n",
93+
" eqx_graph=eqx_graph,\n",
9894
" )\n",
99-
" \n",
95+
"\n",
10096
" def sample_batch(key):\n",
10197
" keys = jax.random.split(key, batch_size)\n",
102-
" \n",
98+
"\n",
10399
" def sample_single(single_key):\n",
104100
" k_init, k_run = jax.random.split(single_key)\n",
105-
" \n",
106-
" # Initialize each block with random spins\n",
101+
"\n",
107102
" init_state = []\n",
108103
" for block in blocks:\n",
109-
" block_state = jax.random.bernoulli(k_init, 0.5, (len(block),)).astype(jnp.int32) * 2 - 1\n",
104+
" block_state = (\n",
105+
" jax.random.bernoulli(k_init, 0.5, (len(block),)).astype(jnp.int32)\n",
106+
" * 2\n",
107+
" - 1\n",
108+
" )\n",
110109
" init_state.append(block_state)\n",
111-
" \n",
110+
"\n",
112111
" samples = sample_chain(init_state, samplers, model, sampling_args, k_run)\n",
113112
" return samples\n",
114-
" \n",
113+
"\n",
115114
" return jax.vmap(sample_single)(keys)\n",
116-
" \n",
115+
"\n",
117116
" jit_sample_batch = jax.jit(sample_batch, device=device)\n",
118-
" \n",
117+
"\n",
119118
" keys = jax.random.split(key, n_reps)\n",
120-
" \n",
119+
"\n",
121120
" start_time = time.time()\n",
122121
" _ = jax.block_until_ready(jit_sample_batch(keys[0]))\n",
123122
" time_with_compile = time.time() - start_time\n",
@@ -132,9 +131,7 @@
132131
" thruput = chain_len * batch_size * len(nodes)\n",
133132
" flips_per_ns = thruput / (time_without_compile * 1e9)\n",
134133
"\n",
135-
" print(\n",
136-
" f\"chain_len: {chain_len}, batch_size: {batch_size}\"\n",
137-
" )\n",
134+
" print(f\"chain_len: {chain_len}, batch_size: {batch_size}\")\n",
138135
" print(\n",
139136
" f\"Time with compile: {time_with_compile:.4f}s, \"\n",
140137
" f\"time without compile: {time_without_compile:.4f}s, \"\n",
@@ -241,21 +238,10 @@
241238
}
242239
],
243240
"source": [
244-
"import matplotlib.pyplot as plt\n",
245-
"\n",
246241
"plt.figure(figsize=(6, 6))\n",
247-
"plt.plot(batch_sizes, flips_per_ns_cpu, label=\"isax CPU\", marker=\"s\")\n",
248242
"if has_gpu:\n",
249243
" plt.plot(batch_sizes, flips_per_ns_gpu, label=\"isax GPU\", marker=\"o\")\n",
250-
"# plt.axhline(\n",
251-
"# 143.80, linestyle=\"--\", color=\"red\", label=\"FPGA\"\n",
252-
"# )\n",
253-
"# plt.axhline(\n",
254-
"# 12.88, linestyle=\"--\", color=\"green\", label=\"Single TPU\"\n",
255-
"# )\n",
256-
"# plt.axhline(\n",
257-
"# 60, linestyle=\"--\", color=\"orange\", label=\"FPGA Hardware (thrml reference)\"\n",
258-
"# )\n",
244+
"plt.plot(batch_sizes, flips_per_ns_cpu, label=\"isax CPU\", marker=\"s\")\n",
259245
"plt.legend()\n",
260246
"plt.xlabel(\"Batch size\")\n",
261247
"plt.xscale(\"log\")\n",

0 commit comments

Comments
 (0)