Skip to content

Commit 6d5ab63

Browse files
committed
copy over comrade integrate
1 parent 8833ea0 commit 6d5ab63

File tree

12 files changed

+543
-40
lines changed

12 files changed

+543
-40
lines changed

.github/workflows/CI.yml

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -34,40 +34,40 @@ jobs:
3434
fail-fast: false
3535
matrix:
3636
version:
37-
- "1.10"
38-
- "1.11"
37+
# - "1.10"
38+
# - "1.11"
3939
- "1.12"
4040
# - 'nightly'
4141
os:
4242
- ubuntu-24.04
4343
# `ubuntu-22.04-arm` is considered more stable than `ubuntu-24.04-arm`:
4444
# <https://github.com/orgs/community/discussions/148648#discussioncomment-12099554>.
45-
- ubuntu-22.04-arm
45+
# - ubuntu-22.04-arm
4646
# Disable `macOS-13` until
4747
# <https://github.com/EnzymeAD/Reactant.jl/issues/867> is resolved.
4848
# - macOS-13
49-
- macOS-latest
50-
- windows-latest
51-
- linux-x86-ct6e-180-4tpu
52-
test_group:
53-
- core
54-
- nn
55-
- integration
56-
- probprog
49+
# - macOS-latest
50+
# - windows-latest
51+
# - linux-x86-ct6e-180-4tpu
52+
# test_group:
53+
# - core
54+
# - nn
55+
# - integration
5756
runtime:
58-
- "pjrt"
57+
# - "pjrt"
5958
- "ifrt"
60-
exclude:
61-
- os: linux-x86-ct6e-180-4tpu
62-
version: "1.10"
63-
- os: linux-x86-ct6e-180-4tpu
64-
runtime: "pjrt"
59+
# exclude:
60+
# - os: linux-x86-ct6e-180-4tpu
61+
# version: "1.10"
62+
# - os: linux-x86-ct6e-180-4tpu
63+
# runtime: "pjrt"
6564
uses: ./.github/workflows/CommonCI.yml
6665
with:
6766
julia_version: ${{ matrix.version }}
6867
os: ${{ matrix.os }}
6968
runtime: ${{ matrix.runtime }}
70-
test_args: ${{ matrix.test_group == 'core' && 'core plugins' || matrix.test_group }}
69+
# test_args: ${{ matrix.test_group == 'core' && 'core plugins' || matrix.test_group }}
70+
test_args: "integration/Comrade"
7171

7272
# This has been broken for a while, originating from CUDA.jl
7373
# test-assertions:
@@ -88,22 +88,21 @@ jobs:
8888
# assertions: true
8989
# test_args: ${{ matrix.test_group == 'core' && 'core plugins' || matrix.test_group }}
9090

91-
downgrade:
92-
strategy:
93-
fail-fast: false
94-
matrix:
95-
test_group:
96-
- core
97-
- nn
98-
- integration
99-
- probprog
100-
runtime:
101-
- "pjrt"
102-
- "ifrt"
103-
uses: ./.github/workflows/CommonCI.yml
104-
with:
105-
julia_version: "1.10"
106-
os: "ubuntu-24.04"
107-
runtime: ${{ matrix.runtime }}
108-
test_args: ${{ matrix.test_group == 'core' && 'core plugins' || matrix.test_group }}
109-
downgrade_testing: true
91+
# downgrade:
92+
# strategy:
93+
# fail-fast: false
94+
# matrix:
95+
# test_group:
96+
# - core
97+
# - nn
98+
# - integration
99+
# runtime:
100+
# - "pjrt"
101+
# - "ifrt"
102+
# uses: ./.github/workflows/CommonCI.yml
103+
# with:
104+
# julia_version: "1.10"
105+
# os: "ubuntu-24.04"
106+
# runtime: ${{ matrix.runtime }}
107+
# test_args: ${{ matrix.test_group == 'core' && 'core plugins' || matrix.test_group }}
108+
# downgrade_testing: true

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,6 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
5959
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
6060
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6161

62-
[sources]
63-
ReactantCore = {path = "lib/ReactantCore"}
64-
6562
[extensions]
6663
ReactantAbstractFFTsExt = "AbstractFFTs"
6764
ReactantArrayInterfaceExt = "ArrayInterface"
@@ -147,5 +144,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
147144
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
148145
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
149146

147+
[sources.ReactantCore]
148+
path = "lib/ReactantCore"
149+
150150
[workspace]
151151
projects = ["docs", "test", "benchmark"]

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2727
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
2828
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
2929
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
30+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3031
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
3132
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
3233
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -68,6 +69,7 @@ OffsetArrays = "1"
6869
OneHotArrays = "0.2.6"
6970
Optimisers = "0.4"
7071
ParallelTestRunner = "2.1"
72+
Pkg = "1.10"
7173
Preferences = "1.4"
7274
PythonCall = "0.9"
7375
Random = "1.10"
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
python = "<3.11,>=3.9,<4"
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
[deps]
2+
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3+
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
4+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
5+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
6+
Comrade = "99d987ce-9a1e-4df8-bc0b-1ea019aa547b"
7+
ComradeBase = "6d8c423b-a35f-4ef1-850c-862fe21f82c4"
8+
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
9+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
10+
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
11+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
12+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
13+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
15+
NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d"
16+
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
17+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
18+
Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
19+
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
20+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
21+
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
22+
VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
23+
VLBILikelihoods = "90db92cd-0007-4c0a-8e51-dbf0782ce592"
24+
VLBISkyModels = "d6343c73-7174-4e0f-bb64-562643efbeca"
25+
26+
[sources.Comrade]
27+
rev = "ptiede-reactant"
28+
url = "https://github.com/ptiede/Comrade.jl"
29+
30+
[sources.ComradeBase]
31+
rev = "main"
32+
url = "https://github.com/ptiede/ComradeBase.jl"
33+
34+
[sources.Reactant]
35+
path = "../../.."
36+
37+
[sources.TransformVariables]
38+
rev = "ptiede-reactant"
39+
url = "https://github.com/ptiede/TransformVariables.jl"
40+
41+
[sources.NFFT]
42+
rev = "ptiede-reactant"
43+
url = "https://github.com/ptiede/NFFT.jl"
44+
45+
[sources.VLBIImagePriors]
46+
rev = "ptiede-reactantperf"
47+
url = "https://github.com/ptiede/VLBIImagePriors.jl"
48+
49+
[sources.VLBILikelihoods]
50+
rev = "ptiede-reactant"
51+
url = "https://github.com/ptiede/VLBILikelihoods.jl"
52+
53+
[sources.VLBISkyModels]
54+
rev = "ptiede-copyto"
55+
url = "https://github.com/EHTJulia/VLBISkyModels.jl"
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import sys
2+
3+
sys.path.insert(0, "Serialized/Fwd")
4+
sys.path.insert(0, "Serialized/Bwd")
5+
6+
7+
import jax
8+
jax.config.update('jax_enable_x64', True)
9+
import jax.numpy as jnp
10+
import blackjax
11+
import enzyme_ad
12+
from functools import partial
13+
14+
import logdensityof as lg
15+
import gl as gl
16+
17+
from logdensityof import run_logdensityof
18+
from gl import run_gl
19+
20+
lg_inputs = lg.load_inputs()
21+
gl_inputs = gl.load_inputs()
22+
23+
tpost = lg_inputs[:-1]
24+
xr = lg_inputs[-1]
25+
26+
27+
28+
29+
jlr = jax.jit(run_logdensityof)
30+
31+
jtpost0 = jnp.array(tpost[0])
32+
jtpost1 = jnp.array(tpost[1])
33+
jtpost2 = jnp.array(tpost[2])
34+
jtpost3 = jnp.array(tpost[3])
35+
jtpost4 = jnp.array(tpost[4])
36+
jxr = jnp.array(xr)
37+
38+
out = jlr(jtpost0, jtpost1, jtpost2, jtpost3, jtpost4, xr)
39+
40+
run_logdensityof(jtpost0, jtpost1, jtpost2, jtpost3, jtpost4, jxr)
41+
run_gl(jtpost0, jtpost1, jtpost2, jtpost3, jtpost4, jxr)
42+
43+
44+
@jax.custom_vjp
45+
def f(x):
46+
out = run_logdensityof(jtpost0, jtpost1, jtpost2, jtpost3, jtpost4, x)
47+
return out[0]
48+
49+
def f_fwd(x):
50+
j = run_gl(jtpost0, jtpost1, jtpost2, jtpost3, jtpost4, x)[0]
51+
return f(x), (j,)
52+
53+
def f_bwd(res, g):
54+
j = res[0]
55+
return (g * j,)
56+
57+
f.defvjp(f_fwd, f_bwd)
58+
59+
logdensity = lambda x: f(**x)
60+
61+
inv_mass_matrix = jnp.ones(len(jxr))
62+
initial_position = {"x": jxr}
63+
64+
rng_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
65+
66+
# adaptation
67+
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity, progress_bar=True)
68+
rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
69+
(state, parameters), _ = warmup.run(warmup_key, initial_position, num_steps=1000)
70+
71+
72+
def inference_loop(rng_key, kernel, init, nsamples):
73+
@jax.jit
74+
def step(state, rng_key):
75+
state, _ = kernel(rng_key, state)
76+
return state, state
77+
78+
keys = jax.random.split(rng_key, nsamples)
79+
_, states = jax.lax.scan(step, init, keys)
80+
return states
81+
82+
# inference loop
83+
rng_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
84+
kernel = blackjax.nuts(logdensity, **parameters).step
85+
states = inference_loop(sample_key, kernel, state, nsamples=1000)
86+
87+
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[project]
2+
name = "sampling"
3+
version = "0.1.0"
4+
description = "Add your description here"
5+
readme = "README.md"
6+
requires-python = ">=3.10, <3.12"
7+
dependencies = [
8+
"blackjax>=1.3",
9+
"ipython>=8.38.0",
10+
"jax[cuda12]==0.5.0",
11+
]

0 commit comments

Comments
 (0)