Skip to content

Commit f3b2c31

Browse files
authored
Fix variance sampling and add tests (#8)
1 parent dc1a0f1 commit f3b2c31

File tree

10 files changed

+311
-115
lines changed

10 files changed

+311
-115
lines changed

.github/workflows/unit-test.yml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
name: unit-test
22

33
on:
4-
pull_request:
4+
schedule:
5+
- cron: "0 0 * * *"
6+
workflow_dispatch:
7+
# Trigger the workflow on push or pull request,
8+
# but only for the main branch
59
push:
6-
branches: [main]
10+
branches:
11+
- main
12+
pull_request:
13+
branches:
14+
- main
715

816
jobs:
917
build:

docker/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ hydra-core==1.3.1
66
optax==0.1.4
77
pandas==1.5.3
88
pre-commit==3.0.4
9+
protobuf==3.20.3 # https://github.com/tensorflow/datasets/issues/4858
910
pytest-cov==4.0.0
1011
pytest-xdist==3.1.0
1112
pytest==7.2.1

imgx/conf/config_amos_diffusion.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ task:
1616
name: "diffusion" # segmentation, diffusion
1717
diffusion:
1818
num_timesteps: 5
19-
num_timesteps_beta: 1000
19+
num_timesteps_beta: 1001
2020
beta:
2121
beta_schedule: "linear" # linear, quadradic, cosine, warmup10, warmup50
2222
beta_start: 0.0001

imgx/conf/config_amos_segmentation.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ task:
1616
name: "segmentation" # segmentation, diffusion
1717
diffusion:
1818
num_timesteps: 5
19-
num_timesteps_beta: 1000
19+
num_timesteps_beta: 1001
2020
beta:
2121
beta_schedule: "linear" # linear, quadradic, cosine, warmup10, warmup50
2222
beta_start: 0.0001

imgx/conf/config_pelvic_diffusion.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ task:
1616
name: "diffusion" # segmentation, diffusion
1717
diffusion:
1818
num_timesteps: 5
19-
num_timesteps_beta: 1000
19+
num_timesteps_beta: 1001
2020
beta:
2121
beta_schedule: "linear" # linear, quadradic, cosine, warmup10, warmup50
2222
beta_start: 0.0001

imgx/conf/config_pelvic_segmentation.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ task:
1616
name: "segmentation" # segmentation, diffusion
1717
diffusion:
1818
num_timesteps: 5
19-
num_timesteps_beta: 1000
19+
num_timesteps_beta: 1001
2020
beta:
2121
beta_schedule: "linear" # linear, quadradic, cosine, warmup10, warmup50
2222
beta_start: 0.0001

imgx/diffusion/gaussian_diffusion.py

Lines changed: 11 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,19 @@
1010
import haiku as hk
1111
import jax.numpy as jnp
1212
import jax.random
13-
import numpy as np
1413

1514
from imgx import EPS
15+
from imgx.diffusion.variance_schedule import (
16+
DiffusionBetaSchedule,
17+
downsample_beta_schedule,
18+
get_beta_schedule,
19+
)
1620
from imgx.metric.distribution import (
1721
discretized_gaussian_log_likelihood,
1822
normal_kl,
1923
)
2024

2125

22-
class DiffusionBetaSchedule(enum.Enum):
23-
"""Class to define beta schedule."""
24-
25-
LINEAR = enum.auto()
26-
QUADRADIC = enum.auto()
27-
COSINE = enum.auto()
28-
WARMUP10 = enum.auto()
29-
WARMUP50 = enum.auto()
30-
31-
3226
class DiffusionModelOutputType(enum.Enum):
3327
"""Class to define model's output meaning.
3428
@@ -88,90 +82,6 @@ def extract_and_expand(
8882
return jnp.expand_dims(arr[t], axis=tuple(range(1, ndim)))
8983

9084

91-
def get_beta_schedule(
92-
num_timesteps: int,
93-
beta_schedule: DiffusionBetaSchedule,
94-
beta_start: float,
95-
beta_end: float,
96-
) -> jnp.ndarray:
97-
"""Get variance (beta) schedule for q(x_t | x_{t-1}).
98-
99-
TODO: open-source code used float64 for beta.
100-
101-
Args:
102-
num_timesteps: number of time steps in total, T.
103-
beta_schedule: schedule for beta.
104-
beta_start: beta for t=0.
105-
beta_end: beta for t=T.
106-
107-
Raises:
108-
ValueError: for unknown schedule.
109-
"""
110-
if beta_schedule == DiffusionBetaSchedule.LINEAR:
111-
return jnp.linspace(
112-
beta_start,
113-
beta_end,
114-
num_timesteps,
115-
)
116-
if beta_schedule == DiffusionBetaSchedule.QUADRADIC:
117-
return (
118-
jnp.linspace(
119-
beta_start**0.5,
120-
beta_end**0.5,
121-
num_timesteps,
122-
)
123-
** 2
124-
)
125-
if beta_schedule == DiffusionBetaSchedule.COSINE:
126-
127-
def alphas_cumprod(t: float) -> float:
128-
"""Eq 17 in https://arxiv.org/abs/2102.09672."""
129-
return np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
130-
131-
max_beta = 0.999
132-
betas = []
133-
for i in range(num_timesteps):
134-
t1 = i / num_timesteps
135-
t2 = (i + 1) / num_timesteps
136-
beta = min(1 - alphas_cumprod(t2) / alphas_cumprod(t1), max_beta)
137-
betas.append(beta)
138-
return jnp.array(betas)
139-
140-
if beta_schedule == DiffusionBetaSchedule.WARMUP10:
141-
num_timesteps_warmup = max(num_timesteps // 10, 1)
142-
betas_warmup = (
143-
jnp.linspace(
144-
beta_start**0.5,
145-
beta_end**0.5,
146-
num_timesteps_warmup,
147-
)
148-
** 2
149-
)
150-
return jnp.concatenate(
151-
[
152-
betas_warmup,
153-
jnp.ones((num_timesteps - num_timesteps_warmup,)) * beta_end,
154-
]
155-
)
156-
if beta_schedule == DiffusionBetaSchedule.WARMUP50:
157-
num_timesteps_warmup = max(num_timesteps // 2, 1)
158-
betas_warmup = (
159-
jnp.linspace(
160-
beta_start**0.5,
161-
beta_end**0.5,
162-
num_timesteps_warmup,
163-
)
164-
** 2
165-
)
166-
return jnp.concatenate(
167-
[
168-
betas_warmup,
169-
jnp.ones((num_timesteps - num_timesteps_warmup,)) * beta_end,
170-
]
171-
)
172-
raise ValueError(f"Unknown beta_schedule {beta_schedule}.")
173-
174-
17585
@dataclasses.dataclass
17686
class GaussianDiffusion(hk.Module):
17787
"""Class for Gaussian diffusion sampling.
@@ -228,24 +138,17 @@ def __init__(
228138

229139
# shape are all (T,)
230140
# corresponding to 0, ..., T-1, where 0 means one step
231-
self.betas = get_beta_schedule(
141+
betas = get_beta_schedule(
232142
num_timesteps=num_timesteps_beta,
233143
beta_schedule=beta_schedule,
234144
beta_start=beta_start,
235145
beta_end=beta_end,
236146
)
237-
if num_timesteps_beta % num_timesteps != 0:
238-
raise ValueError(
239-
f"num_timesteps_beta={num_timesteps_beta} "
240-
f"can't be evenly divided by num_timesteps={num_timesteps}."
241-
)
242-
if num_timesteps != num_timesteps_beta:
243-
# adjust beta
244-
step_scale = num_timesteps_beta // num_timesteps
245-
alphas = 1.0 - self.betas
246-
alphas_cumprod = jnp.cumprod(alphas)
247-
alphas_cumprod = alphas_cumprod[step_scale - 1 :: step_scale]
248-
self.betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
147+
self.betas = downsample_beta_schedule(
148+
betas=betas,
149+
num_timesteps=num_timesteps_beta,
150+
num_timesteps_to_keep=num_timesteps,
151+
)
249152

250153
alphas = 1.0 - self.betas # alpha_t
251154
self.alphas_cumprod = jnp.cumprod(alphas) # \bar{alpha}_t
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""Variance schedule for diffusion models."""
2+
from __future__ import annotations
3+
4+
import enum
5+
6+
import numpy as np
7+
from jax import numpy as jnp
8+
9+
10+
class DiffusionBetaSchedule(enum.Enum):
11+
"""Class to define beta schedule."""
12+
13+
LINEAR = enum.auto()
14+
QUADRADIC = enum.auto()
15+
COSINE = enum.auto()
16+
WARMUP10 = enum.auto()
17+
WARMUP50 = enum.auto()
18+
19+
20+
def get_beta_schedule(
21+
num_timesteps: int,
22+
beta_schedule: DiffusionBetaSchedule,
23+
beta_start: float,
24+
beta_end: float,
25+
) -> jnp.ndarray:
26+
"""Get variance (beta) schedule for q(x_t | x_{t-1}).
27+
28+
Args:
29+
num_timesteps: number of time steps in total, T.
30+
beta_schedule: schedule for beta.
31+
beta_start: beta for t=0.
32+
beta_end: beta for t=T-1.
33+
34+
Returns:
35+
Shape (num_timesteps,) array of beta values, for t=0, ..., T-1.
36+
Values are in ascending order.
37+
38+
Raises:
39+
ValueError: for unknown schedule.
40+
"""
41+
if beta_schedule == DiffusionBetaSchedule.LINEAR:
42+
return jnp.linspace(
43+
beta_start,
44+
beta_end,
45+
num_timesteps,
46+
)
47+
if beta_schedule == DiffusionBetaSchedule.QUADRADIC:
48+
return (
49+
jnp.linspace(
50+
beta_start**0.5,
51+
beta_end**0.5,
52+
num_timesteps,
53+
)
54+
** 2
55+
)
56+
if beta_schedule == DiffusionBetaSchedule.COSINE:
57+
58+
def f(t: float) -> float:
59+
"""Eq 17 in https://arxiv.org/abs/2102.09672.
60+
61+
Args:
62+
t: time step with values in [0, 1].
63+
64+
Returns:
65+
Cumulative product of alpha.
66+
"""
67+
return np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
68+
69+
betas = [0.0]
70+
alphas_cumprod_prev = 1.0
71+
for i in range(1, num_timesteps):
72+
t = i / (num_timesteps - 1)
73+
alphas_cumprod = f(t)
74+
beta = 1 - alphas_cumprod / alphas_cumprod_prev
75+
betas.append(beta)
76+
return jnp.array(betas) * (beta_end - beta_start) + beta_start
77+
78+
if beta_schedule == DiffusionBetaSchedule.WARMUP10:
79+
num_timesteps_warmup = max(num_timesteps // 10, 1)
80+
betas_warmup = (
81+
jnp.linspace(
82+
beta_start**0.5,
83+
beta_end**0.5,
84+
num_timesteps_warmup,
85+
)
86+
** 2
87+
)
88+
return jnp.concatenate(
89+
[
90+
betas_warmup,
91+
jnp.ones((num_timesteps - num_timesteps_warmup,)) * beta_end,
92+
]
93+
)
94+
if beta_schedule == DiffusionBetaSchedule.WARMUP50:
95+
num_timesteps_warmup = max(num_timesteps // 2, 1)
96+
betas_warmup = (
97+
jnp.linspace(
98+
beta_start**0.5,
99+
beta_end**0.5,
100+
num_timesteps_warmup,
101+
)
102+
** 2
103+
)
104+
return jnp.concatenate(
105+
[
106+
betas_warmup,
107+
jnp.ones((num_timesteps - num_timesteps_warmup,)) * beta_end,
108+
]
109+
)
110+
raise ValueError(f"Unknown beta_schedule {beta_schedule}.")
111+
112+
113+
def downsample_beta_schedule(
114+
betas: jnp.ndarray,
115+
num_timesteps: int,
116+
num_timesteps_to_keep: int,
117+
) -> jnp.ndarray:
118+
"""Downsample beta schedule.
119+
120+
Args:
121+
betas: beta schedule, shape (num_timesteps,).
122+
Values are in ascending order.
123+
num_timesteps: number of time steps in total, T.
124+
num_timesteps_to_keep: number of time steps to keep.
125+
126+
Returns:
127+
Downsampled beta schedule, shape (num_timesteps_to_keep,).
128+
"""
129+
if betas.shape != (num_timesteps,):
130+
raise ValueError(
131+
f"betas.shape ({betas.shape}) must be equal to "
132+
f"(num_timesteps,)=({num_timesteps},)"
133+
)
134+
if (num_timesteps - 1) % (num_timesteps_to_keep - 1) != 0:
135+
raise ValueError(
136+
f"num_timesteps-1={num_timesteps-1} can't be evenly divided by "
137+
f"num_timesteps_to_keep-1={num_timesteps_to_keep-1}."
138+
)
139+
if num_timesteps_to_keep < 2:
140+
raise ValueError(
141+
f"num_timesteps_to_keep ({num_timesteps_to_keep}) must be >= 2."
142+
)
143+
if num_timesteps_to_keep == num_timesteps:
144+
return betas
145+
if num_timesteps_to_keep < num_timesteps:
146+
step_scale = (num_timesteps - 1) // (num_timesteps_to_keep - 1)
147+
beta0 = betas[0]
148+
alphas = 1.0 - betas
149+
alphas_cumprod = jnp.cumprod(alphas)
150+
# (num_timesteps_to_keep,)
151+
alphas_cumprod = alphas_cumprod[::step_scale]
152+
# (num_timesteps_to_keep-1,)
153+
betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
154+
# (num_timesteps_to_keep,)
155+
betas = jnp.append(beta0, betas)
156+
return betas
157+
raise ValueError(
158+
f"num_timesteps_to_keep ({num_timesteps_to_keep}) "
159+
f"must be <= num_timesteps ({num_timesteps})"
160+
)

0 commit comments

Comments
 (0)