Skip to content

Commit d4c11d8

Browse files
committed
work
1 parent e36f4d1 commit d4c11d8

File tree

4 files changed

+30
-7
lines changed

4 files changed

+30
-7
lines changed

.github/workflows/deploy-docs.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ jobs:
2626
- name: Install dependencies
2727
run: |
2828
python -m pip install --upgrade pip
29-
python -m pip install .
30-
python -m pip install mkdocs mkdocs-material mkdocstrings[python] pymdown-extensions
29+
python -m pip install ".[docs]"
3130
# https://github.com/mhausenblas/mkdocs-deploy-gh-pages/blob/master/action.sh
3231
- name: Build docs
3332
run: |

examples/03_ising_thermodynamics.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@
344344
"# jax.random.choice(k_init, jnp.array([1, -1]), shape=(num_odd,)),\n",
345345
"# ]\n",
346346
"\n",
347-
"# model_with_beta = IsingModel(weights=beta * edge_weights, \n",
347+
"# model_with_beta = IsingModel(weights=beta * edge_weights,\n",
348348
"# biases=beta * node_biases)\n",
349349
"\n",
350350
"# samples = sample_fn(\n",

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ testing = [
4141
"ruff==0.11.0",
4242
"pyright==1.1.399",
4343
]
44+
docs = [
45+
"hippogriffe==0.2.0",
46+
"griffe==1.7.3",
47+
"mkdocs==1.6.1",
48+
"mkdocs-include-exclude-files==0.1.0",
49+
"mkdocs-ipynb==0.1.0",
50+
"mkdocs-material==9.6.7",
51+
"mkdocstrings[python]==0.28.3",
52+
"pymdown-extensions==10.14.3",
53+
]
4454

4555
[tool.setuptools]
4656
packages = { find = { include = ["isax", "isax.*"] } }

tests/test_sample.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
import unittest
22
import jax
33
import jax.numpy as jnp
4-
from isax import BlockGraph, Edge, Node, IsingModel, IsingSampler, SamplingArgs, sample_chain
4+
from isax import (
5+
BlockGraph,
6+
Edge,
7+
Node,
8+
IsingModel,
9+
IsingSampler,
10+
SamplingArgs,
11+
sample_chain,
12+
)
513

614

715
class TestSampleChain(unittest.TestCase):
@@ -17,23 +25,29 @@ def test_readme_example(self):
1725
edges.append(Edge(nodes[i], nodes[(x * L + (y + 1) % L)]))
1826
edges.append(Edge(nodes[i], nodes[((x + 1) % L) * L + y]))
1927

20-
even = [nodes[x * L + y] for x in range(L) for y in range(L) if (x + y) % 2 == 0]
28+
even = [
29+
nodes[x * L + y] for x in range(L) for y in range(L) if (x + y) % 2 == 0
30+
]
2131
odd = [nodes[x * L + y] for x in range(L) for y in range(L) if (x + y) % 2 == 1]
2232

2333
graph = BlockGraph([even, odd], edges)
2434
params = graph.get_sampling_params()
2535

2636
model = IsingModel(weights=jnp.ones(len(edges)), biases=jnp.zeros(L * L))
2737
sampler = IsingSampler()
28-
sampling_args = SamplingArgs(gibbs_steps=100, blocks_to_sample=[0, 1], data=params)
38+
sampling_args = SamplingArgs(
39+
gibbs_steps=100, blocks_to_sample=[0, 1], data=params
40+
)
2941

3042
key = jax.random.key(0)
3143
init_state = [
3244
jax.random.choice(key, jnp.array([-1, 1]), (len(even),)),
3345
jax.random.choice(key, jnp.array([-1, 1]), (len(odd),)),
3446
]
3547

36-
samples = sample_chain(init_state, [sampler, sampler], model, sampling_args, key)
48+
samples = sample_chain(
49+
init_state, [sampler, sampler], model, sampling_args, key
50+
)
3751

3852
self.assertEqual(len(samples), 2)
3953
self.assertEqual(samples[0].shape, (100, len(even)))

0 commit comments

Comments
 (0)