6060SimplicialProcess = simplicial .SimplicialProcess
6161SimplicialSchedule = schedules .SimplicialSchedule
6262
63+
6364################################################################################
64- # MARK: DDIM Step
65+ # MARK: Beta Shrinkage
6566################################################################################
6667
67- # TODO(vdebortoli): Add support for churn.
68+
@kt.typechecked
def log_beta_shrinkage(
    key: jax.Array,
    log_x: jax.Array,
    concentration: jax.Array,
    kappa: float,
) -> jax.Array:
  """Beta shrinkage of a Dirichlet sample.

  Let log(X) such that X ~ Dir(concentration) and kappa in [0, 1]. Then this
  function returns log(Y) such that Y ~ Dir(kappa * concentration).

  To do so we leverage the following identity.
  Let B ~ Beta(a, b) with a = kappa * concentration and b = (1 - kappa) *
  concentration.
  Then B X / sum(B X) has the same distribution as Dir(kappa * concentration).
  We call this process "Beta-shrinkage".

  Args:
    key: the random key.
    log_x: the log-Dirichlet sample of shape (..., num_categories).
    concentration: the concentration scalar or array (broadcastable to
      log_x.shape).
    kappa: the shrinkage parameter in [0, 1].

  Returns:
    the shrunk log-sample, normalized over the last axis.

  Raises:
    ValueError: if kappa is outside [0, 1].
  """
  # kappa is a static Python float, so this validation is jit-safe. Without
  # it, a kappa outside [0, 1] silently yields negative Beta shape
  # parameters and NaN samples downstream.
  if not 0.0 <= kappa <= 1.0:
    raise ValueError(f'kappa must be in [0, 1], got {kappa}.')
  if kappa == 1.0:
    # No shrinkage: Dir(1 * concentration) is already the input distribution.
    return log_x
  alpha_vec = jnp.broadcast_to(concentration, log_x.shape)

  # Component-wise Gamma thinning: if G_i ~ Gamma(alpha_i) and
  # B_i ~ Beta(kappa * alpha_i, (1 - kappa) * alpha_i), then
  # B_i G_i ~ Gamma(kappa * alpha_i); renormalizing B X therefore gives
  # Dir(kappa * concentration).
  log_b, _ = random_utils.sample_log_beta_joint(
      key, kappa * alpha_vec, (1.0 - kappa) * alpha_vec, shape=alpha_vec.shape
  )

  log_y = log_b + log_x
  # Renormalize in log-space so that exp(log_y) sums to one over the last
  # axis (stays on the simplex).
  log_y = log_y - jax.nn.logsumexp(log_y, axis=-1, keepdims=True)
  return log_y
107+
108+
109+ ################################################################################
110+ # MARK: DDIM Step
111+ ################################################################################
68112
69113
70114@dataclasses .dataclass (frozen = True , kw_only = True )
71115class SimplicialDDIMStep (SamplerStep ):
72116 """This is the simplicial version of the DDIM step."""
73117
74118 corruption_process : SimplicialProcess
119+ churn : float = 1.0
75120
76121 @kt .typechecked
77122 def initialize (
@@ -84,8 +129,6 @@ def initialize(
84129 step_info = initial_step_info ,
85130 aux = {'logits' : initial_noise },
86131 )
87- # `logits` need to be passed in `aux` dictionary to a performance
88- # bug when using TPU. Needs to be investigated.
89132
90133 @kt .typechecked
91134 def update (
@@ -114,31 +157,63 @@ def update(
114157 )['logits' ]
115158
116159 # Sample hard token
117- sample_key , beta_key = jax .random .split (key )
160+ key , sample_key = jax .random .split (key )
118161 sample_idx = jax .random .categorical (key = sample_key , logits = logits )
119162 num_cats = self .corruption_process .process_num_categories
120163 one_hot_mask = jax .nn .one_hot (sample_idx , num_cats , dtype = log_xt .dtype )
121164 log_sample_oh = jnp .where (one_hot_mask > 0.5 , 0.0 , - 1e30 )
122165
123- # Compute Beta shape parameters
166+ # Compute parameters
167+ eps = self .corruption_process .temperature
124168 alpha_t = self .corruption_process .schedule .alpha (time )
125169 alpha_s = self .corruption_process .schedule .alpha (next_time )
126170
127- shape_0 = self . corruption_process . temperature / (1.0 - alpha_t )
128- shape_1 = self . corruption_process . temperature / (1.0 - alpha_s ) - shape_0
171+ bar_beta_t = eps / (1.0 - alpha_t )
172+ bar_beta_s = eps / (1.0 - alpha_s )
129173
130- # Broadcasting
131174 target_shape = log_xt .shape [:- 1 ] + (1 ,)
132- shape_0 = jnp .broadcast_to (shape_0 , target_shape )
133- shape_1 = jnp .broadcast_to (shape_1 , target_shape )
134-
135- # Sample from Beta(shape_0, shape_1)
136- log_w , log_1_minus_w = random_utils .sample_log_beta_joint (
137- beta_key , shape_0 , shape_1 , shape = shape_0 .shape
138- )
139175
140- term_1 = log_w + log_xt
141- term_2 = log_1_minus_w + log_sample_oh
176+ if self .churn == 0.0 :
177+ # Regular DDIM step
178+ shape_0 = bar_beta_t
179+ shape_1 = bar_beta_s - shape_0
180+
181+ _ , beta_key = jax .random .split (key )
182+ log_w , log_1_minus_w = random_utils .sample_log_beta_joint (
183+ beta_key , shape_0 , shape_1 , shape = target_shape
184+ )
185+
186+ term_1 = log_w + log_xt
187+ term_2 = log_1_minus_w + log_sample_oh
188+ else :
189+ # Churn step
190+ pi = self .corruption_process .invariant_probs_vec
191+ h_t = alpha_t / (1.0 - alpha_t )
192+ h_s = alpha_s / (1.0 - alpha_s )
193+
194+ key , f_key = jax .random .split (key )
195+ log_pt_kappa = log_beta_shrinkage (
196+ f_key , log_x = log_xt , concentration = bar_beta_t , kappa = 1.0 - self .churn
197+ )
198+
199+ key , v_key = jax .random .split (key )
200+ alpha_v = (
201+ self .churn * eps * pi
202+ + (eps * h_s - (1.0 - self .churn ) * eps * h_t ) * one_hot_mask
203+ )
204+ log_v = random_utils .log_dirichlet_fast (v_key , alpha = alpha_v , shape = ())
205+
206+ # Sample W from Beta(kappa * bar_beta_t, bar_beta_s - kappa * bar_beta_t)
207+ _ , beta_key = jax .random .split (key )
208+ log_w , log_1_minus_w = random_utils .sample_log_beta_joint (
209+ beta_key ,
210+ (1.0 - self .churn ) * bar_beta_t ,
211+ bar_beta_s - (1.0 - self .churn ) * bar_beta_t ,
212+ shape = target_shape ,
213+ )
214+
215+ term_1 = log_w + log_pt_kappa
216+ term_2 = log_1_minus_w + log_v
142217
143218 new_xt = jnp .logaddexp (term_1 , term_2 )
144219
@@ -147,8 +222,6 @@ def update(
147222 step_info = next_step_info ,
148223 aux = {'logits' : logits },
149224 )
150- # `logits` need to be passed in `aux` dictionary to a performance
151- # bug when using TPU. Needs to be investigated.
152225
153226 @kt .typechecked
154227 def finalize (
0 commit comments