Skip to content

Commit 533e92c

Browse files
vdeborto authored and Hackable Diffusion Authors committed
Incorporate churn logic into SimplicialDDIMStep.
PiperOrigin-RevId: 889175381
1 parent 647579c commit 533e92c

File tree

2 files changed

+353
-20
lines changed

2 files changed

+353
-20
lines changed

hackable_diffusion/lib/sampling/simplicial_step_sampler.py

Lines changed: 93 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,63 @@
6060
SimplicialProcess = simplicial.SimplicialProcess
6161
SimplicialSchedule = schedules.SimplicialSchedule
6262

63+
6364
################################################################################
64-
# MARK: DDIM Step
65+
# MARK: Beta Shrinkage
6566
################################################################################
6667

67-
# TODO(vdebortoli): Add support for churn.
68+
69+
@kt.typechecked
def log_beta_shrinkage(
    key: jax.Array,
    log_x: jax.Array,
    concentration: jax.Array,
    kappa: float,
) -> jax.Array:
  """Shrinks the concentration of a log-Dirichlet sample by a factor kappa.

  Given log(X) with X ~ Dir(concentration) and kappa in [0, 1], returns log(Y)
  with Y ~ Dir(kappa * concentration).

  The construction ("Beta-shrinkage") relies on the following identity: draw
  B ~ Beta(a, b) with a = kappa * concentration and
  b = (1 - kappa) * concentration. Then B * X / sum(B * X) is distributed as
  Dir(kappa * concentration). Everything is carried out in log-space.

  Args:
    key: the random key.
    log_x: the log-Dirichlet sample of shape (..., num_categories).
    concentration: the concentration scalar or array; it is broadcast to the
      shape of `log_x`.
    kappa: the shrinkage parameter in [0, 1].

  Returns:
    The shrunk log-sample, normalized over the last axis.
  """
  # With kappa == 1 the target concentration equals the input concentration,
  # so the sample is returned untouched and no randomness is consumed.
  if kappa == 1.0:
    return log_x

  full_concentration = jnp.broadcast_to(concentration, log_x.shape)
  beta_a = kappa * full_concentration
  beta_b = (1.0 - kappa) * full_concentration

  # Only log(B) is needed; the second returned value is discarded.
  log_b, _unused = random_utils.sample_log_beta_joint(
      key, beta_a, beta_b, shape=full_concentration.shape
  )

  # Multiply by B in log-space, then renormalize over the last axis.
  unnormalized = log_b + log_x
  log_normalizer = jax.nn.logsumexp(unnormalized, axis=-1, keepdims=True)
  return unnormalized - log_normalizer
107+
108+
109+
################################################################################
110+
# MARK: DDIM Step
111+
################################################################################
68112

69113

70114
@dataclasses.dataclass(frozen=True, kw_only=True)
71115
class SimplicialDDIMStep(SamplerStep):
72116
"""This is the simplicial version of the DDIM step."""
73117

74118
corruption_process: SimplicialProcess
119+
churn: float = 1.0
75120

76121
@kt.typechecked
77122
def initialize(
@@ -84,8 +129,6 @@ def initialize(
84129
step_info=initial_step_info,
85130
aux={'logits': initial_noise},
86131
)
87-
# `logits` need to be passed in `aux` dictionary to a performance
88-
# bug when using TPU. Needs to be investigated.
89132

90133
@kt.typechecked
91134
def update(
@@ -114,31 +157,63 @@ def update(
114157
)['logits']
115158

116159
# Sample hard token
117-
sample_key, beta_key = jax.random.split(key)
160+
key, sample_key = jax.random.split(key)
118161
sample_idx = jax.random.categorical(key=sample_key, logits=logits)
119162
num_cats = self.corruption_process.process_num_categories
120163
one_hot_mask = jax.nn.one_hot(sample_idx, num_cats, dtype=log_xt.dtype)
121164
log_sample_oh = jnp.where(one_hot_mask > 0.5, 0.0, -1e30)
122165

123-
# Compute Beta shape parameters
166+
# Compute parameters
167+
eps = self.corruption_process.temperature
124168
alpha_t = self.corruption_process.schedule.alpha(time)
125169
alpha_s = self.corruption_process.schedule.alpha(next_time)
126170

127-
shape_0 = self.corruption_process.temperature / (1.0 - alpha_t)
128-
shape_1 = self.corruption_process.temperature / (1.0 - alpha_s) - shape_0
171+
bar_beta_t = eps / (1.0 - alpha_t)
172+
bar_beta_s = eps / (1.0 - alpha_s)
129173

130-
# Broadcasting
131174
target_shape = log_xt.shape[:-1] + (1,)
132-
shape_0 = jnp.broadcast_to(shape_0, target_shape)
133-
shape_1 = jnp.broadcast_to(shape_1, target_shape)
134-
135-
# Sample from Beta(shape_0, shape_1)
136-
log_w, log_1_minus_w = random_utils.sample_log_beta_joint(
137-
beta_key, shape_0, shape_1, shape=shape_0.shape
138-
)
139175

140-
term_1 = log_w + log_xt
141-
term_2 = log_1_minus_w + log_sample_oh
176+
if self.churn == 0.0:
177+
# Regular DDIM step
178+
shape_0 = bar_beta_t
179+
shape_1 = bar_beta_s - shape_0
180+
181+
_, beta_key = jax.random.split(key)
182+
log_w, log_1_minus_w = random_utils.sample_log_beta_joint(
183+
beta_key, shape_0, shape_1, shape=target_shape
184+
)
185+
186+
term_1 = log_w + log_xt
187+
term_2 = log_1_minus_w + log_sample_oh
188+
else:
189+
# Churn step
190+
pi = self.corruption_process.invariant_probs_vec
191+
h_t = alpha_t / (1.0 - alpha_t)
192+
h_s = alpha_s / (1.0 - alpha_s)
193+
194+
key, f_key = jax.random.split(key)
195+
log_pt_kappa = log_beta_shrinkage(
196+
f_key, log_x=log_xt, concentration=bar_beta_t, kappa=1.0 - self.churn
197+
)
198+
199+
key, v_key = jax.random.split(key)
200+
alpha_v = (
201+
self.churn * eps * pi
202+
+ (eps * h_s - (1.0 - self.churn) * eps * h_t) * one_hot_mask
203+
)
204+
log_v = random_utils.log_dirichlet_fast(v_key, alpha=alpha_v, shape=())
205+
206+
# Sample W from Beta(kappa * bar_beta_t, bar_beta_s - kappa * bar_beta_t), where kappa = 1 - churn
207+
_, beta_key = jax.random.split(key)
208+
log_w, log_1_minus_w = random_utils.sample_log_beta_joint(
209+
beta_key,
210+
(1.0 - self.churn) * bar_beta_t,
211+
bar_beta_s - (1.0 - self.churn) * bar_beta_t,
212+
shape=target_shape,
213+
)
214+
215+
term_1 = log_w + log_pt_kappa
216+
term_2 = log_1_minus_w + log_v
142217

143218
new_xt = jnp.logaddexp(term_1, term_2)
144219

@@ -147,8 +222,6 @@ def update(
147222
step_info=next_step_info,
148223
aux={'logits': logits},
149224
)
150-
# `logits` need to be passed in `aux` dictionary to a performance
151-
# bug when using TPU. Needs to be investigated.
152225

153226
@kt.typechecked
154227
def finalize(

0 commit comments

Comments
 (0)