|
7 | 7 | "source": [ |
8 | 8 | "# Flips Per Second\n", |
9 | 9 | "\n", |
10 | | - "In this example we will be looking as Flips Per Second (FPS) as a metric for evaluating the speed of a Gibbs sampling program. This highlights the power of GPUs for workloads which have many parallel Gibbs chains to be sampled." |
| 10 | + "In this example we will be looking at Flips Per Second (FPS) as a metric for evaluating the speed of a Gibbs sampling program." |
11 | 11 | ] |
12 | 12 | }, |
13 | 13 | { |
|
19 | 19 | "source": [ |
20 | 20 | "import time\n", |
21 | 21 | "\n", |
| 22 | + "import dwave_networkx\n", |
22 | 23 | "import jax\n", |
23 | 24 | "import jax.numpy as jnp\n", |
24 | | - "import dwave_networkx\n", |
| 25 | + "import matplotlib.pyplot as plt\n", |
25 | 26 | "import networkx as nx\n", |
26 | | - "from isax.block import BlockGraph, Edge, AbstractNode\n", |
27 | | - "from isax.sample import IsingModel, IsingSampler, SamplingArgs, sample_chain" |
| 27 | + "from isax import (\n", |
| 28 | + " BlockGraph,\n", |
| 29 | + " Edge,\n", |
| 30 | + " IsingModel,\n", |
| 31 | + " IsingSampler,\n", |
| 32 | + " Node,\n", |
| 33 | + " sample_chain,\n", |
| 34 | + " SamplingArgs,\n", |
| 35 | + ")" |
28 | 36 | ] |
29 | 37 | }, |
30 | 38 | { |
|
35 | 43 | "outputs": [], |
36 | 44 | "source": [ |
37 | 45 | "def create_dwave_pegasus_graph(pegasus_size, key):\n", |
38 | | - " \"\"\"Create DWave Pegasus graph like thrml does\"\"\"\n", |
39 | | - " # Create the DWave Pegasus graph\n", |
40 | 46 | " graph = dwave_networkx.pegasus_graph(pegasus_size)\n", |
41 | | - " \n", |
42 | | - " # Convert coordinates to AbstractNode objects (isax uses AbstractNode, not SpinNode)\n", |
43 | | - " coord_to_node = {coord: AbstractNode() for coord in graph.nodes}\n", |
| 47 | + " coord_to_node = {coord: Node() for coord in graph.nodes}\n", |
44 | 48 | " nx.relabel_nodes(graph, coord_to_node, copy=False)\n", |
45 | | - " \n", |
46 | | - " # Convert to our format\n", |
47 | 49 | " nodes = list(graph.nodes)\n", |
48 | 50 | " edges = [Edge(u, v) for u, v in graph.edges()]\n", |
49 | | - " \n", |
50 | | - " # Use graph coloring to create blocks (similar to thrml approach)\n", |
| 51 | + " # todo: dwave ships colorings already?\n", |
51 | 52 | " coloring = nx.coloring.greedy_color(graph, strategy=\"DSATUR\")\n", |
52 | 53 | " n_colors = max(coloring.values()) + 1\n", |
53 | 54 | " blocks = [[] for _ in range(n_colors)]\n", |
54 | | - " \n", |
55 | | - " # Form color groups\n", |
56 | 55 | " for node in graph.nodes:\n", |
57 | 56 | " blocks[coloring[node]].append(node)\n", |
58 | | - " \n", |
59 | | - " # Create BlockGraph with corrected argument order\n", |
60 | 57 | " block_graph = BlockGraph(blocks, edges)\n", |
61 | | - " \n", |
62 | | - " # Generate random parameters\n", |
63 | 58 | " key1, key2 = jax.random.split(key)\n", |
64 | 59 | " biases = jax.random.uniform(key1, (len(nodes),), minval=-0.1, maxval=0.1)\n", |
65 | 60 | " weights = jax.random.uniform(key2, (len(edges),), minval=-0.1, maxval=0.1)\n", |
66 | | - " \n", |
| 61 | + "\n", |
67 | 62 | " model = IsingModel(weights=weights, biases=biases)\n", |
68 | | - " \n", |
69 | | - " print(f\"Created Pegasus graph with {len(nodes)} nodes, {len(edges)} edges, {n_colors} color blocks\")\n", |
70 | | - " \n", |
| 63 | + "\n", |
| 64 | + " print(\n", |
| 65 | + " f\"Created Pegasus graph with {len(nodes)} nodes, {len(edges)} edges, {n_colors} color blocks\"\n", |
| 66 | + " )\n", |
| 67 | + "\n", |
71 | 68 | " return model, block_graph, nodes, blocks" |
72 | 69 | ] |
73 | 70 | }, |
|
82 | 79 | " model, block_graph, nodes, blocks, chain_len, batch_size, n_reps, device\n", |
83 | 80 | "):\n", |
84 | 81 | " key = jax.random.key(42)\n", |
85 | | - " \n", |
| 82 | + "\n", |
86 | 83 | " (adjs, masks, edge_infos), eqx_graph = block_graph.get_sampling_params()\n", |
87 | | - " \n", |
88 | | - " # Create samplers for each block\n", |
| 84 | + "\n", |
89 | 85 | " samplers = [IsingSampler() for _ in range(len(blocks))]\n", |
90 | | - " \n", |
| 86 | + "\n", |
91 | 87 | " sampling_args = SamplingArgs(\n", |
92 | 88 | " gibbs_steps=chain_len,\n", |
93 | 89 | " blocks_to_sample=list(range(len(blocks))),\n", |
94 | 90 | " adjs=adjs,\n", |
95 | 91 | " masks=masks,\n", |
96 | 92 | " edge_info=edge_infos,\n", |
97 | | - " eqx_graph=eqx_graph\n", |
| 93 | + " eqx_graph=eqx_graph,\n", |
98 | 94 | " )\n", |
99 | | - " \n", |
| 95 | + "\n", |
100 | 96 | " def sample_batch(key):\n", |
101 | 97 | " keys = jax.random.split(key, batch_size)\n", |
102 | | - " \n", |
| 98 | + "\n", |
103 | 99 | " def sample_single(single_key):\n", |
104 | 100 | " k_init, k_run = jax.random.split(single_key)\n", |
105 | | - " \n", |
106 | | - " # Initialize each block with random spins\n", |
| 101 | + "\n", |
107 | 102 | " init_state = []\n", |
108 | 103 | " for block in blocks:\n", |
109 | | - " block_state = jax.random.bernoulli(k_init, 0.5, (len(block),)).astype(jnp.int32) * 2 - 1\n", |
| 104 | + " block_state = (\n", |
| 105 | + " jax.random.bernoulli(k_init, 0.5, (len(block),)).astype(jnp.int32)\n", |
| 106 | + " * 2\n", |
| 107 | + " - 1\n", |
| 108 | + " )\n", |
110 | 109 | " init_state.append(block_state)\n", |
111 | | - " \n", |
| 110 | + "\n", |
112 | 111 | " samples = sample_chain(init_state, samplers, model, sampling_args, k_run)\n", |
113 | 112 | " return samples\n", |
114 | | - " \n", |
| 113 | + "\n", |
115 | 114 | " return jax.vmap(sample_single)(keys)\n", |
116 | | - " \n", |
| 115 | + "\n", |
117 | 116 | " jit_sample_batch = jax.jit(sample_batch, device=device)\n", |
118 | | - " \n", |
| 117 | + "\n", |
119 | 118 | " keys = jax.random.split(key, n_reps)\n", |
120 | | - " \n", |
| 119 | + "\n", |
121 | 120 | " start_time = time.time()\n", |
122 | 121 | " _ = jax.block_until_ready(jit_sample_batch(keys[0]))\n", |
123 | 122 | " time_with_compile = time.time() - start_time\n", |
|
132 | 131 | " thruput = chain_len * batch_size * len(nodes)\n", |
133 | 132 | " flips_per_ns = thruput / (time_without_compile * 1e9)\n", |
134 | 133 | "\n", |
135 | | - " print(\n", |
136 | | - " f\"chain_len: {chain_len}, batch_size: {batch_size}\"\n", |
137 | | - " )\n", |
| 134 | + " print(f\"chain_len: {chain_len}, batch_size: {batch_size}\")\n", |
138 | 135 | " print(\n", |
139 | 136 | " f\"Time with compile: {time_with_compile:.4f}s, \"\n", |
140 | 137 | " f\"time without compile: {time_without_compile:.4f}s, \"\n", |
|
241 | 238 | } |
242 | 239 | ], |
243 | 240 | "source": [ |
244 | | - "import matplotlib.pyplot as plt\n", |
245 | | - "\n", |
246 | 241 | "plt.figure(figsize=(6, 6))\n", |
247 | | - "plt.plot(batch_sizes, flips_per_ns_cpu, label=\"isax CPU\", marker=\"s\")\n", |
248 | 242 | "if has_gpu:\n", |
249 | 243 | " plt.plot(batch_sizes, flips_per_ns_gpu, label=\"isax GPU\", marker=\"o\")\n", |
250 | | - "# plt.axhline(\n", |
251 | | - "# 143.80, linestyle=\"--\", color=\"red\", label=\"FPGA\"\n", |
252 | | - "# )\n", |
253 | | - "# plt.axhline(\n", |
254 | | - "# 12.88, linestyle=\"--\", color=\"green\", label=\"Single TPU\"\n", |
255 | | - "# )\n", |
256 | | - "# plt.axhline(\n", |
257 | | - "# 60, linestyle=\"--\", color=\"orange\", label=\"FPGA Hardware (thrml reference)\"\n", |
258 | | - "# )\n", |
| 244 | + "plt.plot(batch_sizes, flips_per_ns_cpu, label=\"isax CPU\", marker=\"s\")\n", |
259 | 245 | "plt.legend()\n", |
260 | 246 | "plt.xlabel(\"Batch size\")\n", |
261 | 247 | "plt.xscale(\"log\")\n", |
|
0 commit comments