Skip to content

Commit fe6d45a

Browse files
author
Hackable Diffusion Authors
committed
Implement SelfConditioningDiffusionNetwork.
PiperOrigin-RevId: 890391263
1 parent 0805a54 commit fe6d45a

File tree

2 files changed

+316
-0
lines changed

2 files changed

+316
-0
lines changed

hackable_diffusion/lib/diffusion_network.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,120 @@ def __call__(
190190
return {self.prediction_type: backbone_outputs}
191191

192192

193+
################################################################################
194+
# MARK: Self-Conditioning Diffusion Network
195+
################################################################################
196+
197+
198+
class SelfConditioningDiffusionNetwork(DiffusionNetwork):
  """DiffusionNetwork with self-conditioning on x₀ predictions.

  Implements self-conditioning on x₀ predictions from the discrete diffusion
  literature (Chen et al. "Analog Bits"; Strudel et al.; D3PM).

  During training, with probability ``self_cond_prob`` (default 0.5):

  1. Run the network once with zero x̂₀ input to get initial logits.
  2. ``stop_gradient`` on the initial logits.
  3. Concatenate the logits to the noisy input xₜ along the last axis.
  4. Run the network again and return the output.

  During inference (``is_training=False``), self-conditioning is always
  applied.

  The ``backbone_network`` is expected to accept the wider input (xₜ
  concatenated with x̂₀ logits on the last axis).

  Attributes:
    num_output_classes: Number of output classes (categories) for the
      prediction. Used to create zero-filled logits of the correct shape for
      the first forward pass. Must be overridden with a positive integer.
    self_cond_prob: Probability of applying self-conditioning during training.
      During inference, self-conditioning is always applied.
    rng_collection: The PRNG collection name to use for drawing the
      self-conditioning mask. Defaults to 'dropout'.
  """

  num_output_classes: int = -1
  self_cond_prob: float = 0.5
  rng_collection: str = 'dropout'

  @nn.compact
  @kt.typechecked
  def __call__(
      self,
      time: TimeArray,
      xt: DataArray,
      conditioning: Conditioning | None,
      is_training: bool,
  ) -> TargetInfo:
    """Runs two backbone passes, self-conditioning the second on the first.

    Args:
      time: Diffusion time for each batch element.
      xt: Noisy input at time t; the last axis is the channel axis that the
        x̂₀ logits are concatenated onto.
      conditioning: Optional conditioning inputs for the encoder.
      is_training: Whether the call is a training step. Controls both the
        stochastic self-conditioning mask and the submodules' training mode.

    Returns:
      A TargetInfo mapping ``self.prediction_type`` to the backbone outputs
      of the second (self-conditioned) pass.

    Raises:
      ValueError: If ``num_output_classes`` was not set to a positive integer.
    """
    if self.num_output_classes <= 0:
      raise ValueError(
          '`num_output_classes` must be a positive integer, '
          f'got {self.num_output_classes}.'
      )

    # Rescale time and encode conditioning once for both passes.
    time_rescaled = (
        self.time_rescaler(time) if self.time_rescaler is not None else time
    )

    conditioning_embeddings = self.conditioning_encoder.copy(
        name='ConditioningEncoder'
    )(
        time=time_rescaled,
        conditioning=conditioning,
        is_training=is_training,
    )

    # A single backbone module instance is reused for both passes so they
    # share parameters.
    backbone_module = self.backbone_network.copy(name='Backbone')

    def _run_backbone(x0_logits):
      """Concatenates `x0_logits` onto xt, rescales, and runs the backbone."""
      x = jnp.concatenate([xt, x0_logits], axis=-1)
      if self.input_rescaler is not None:
        x = self.input_rescaler(time, x)
      return backbone_module(
          x=x,
          conditioning_embeddings=conditioning_embeddings,
          is_training=is_training,
      )

    # Zero logits with the same batch/spatial shape as xt; used both as the
    # first-pass x̂₀ input and as the "no self-conditioning" fallback.
    zero_logits = jnp.zeros(
        xt.shape[:-1] + (self.num_output_classes,), dtype=xt.dtype
    )

    # First pass: run with zero logits to get initial x̂₀, detached so no
    # gradient flows through the first pass.
    x0_hat_logits = jax.lax.stop_gradient(_run_backbone(zero_logits))

    if is_training:
      # With probability self_cond_prob, keep the self-conditioning logits;
      # otherwise fall back to zeros (no self-conditioning). The draw is a
      # single scalar per call, broadcast over the batch by jnp.where.
      do_self_cond = (
          jax.random.uniform(self.make_rng(self.rng_collection))
          < self.self_cond_prob
      )
      x0_hat_logits = jnp.where(do_self_cond, x0_hat_logits, zero_logits)

    # Second pass: run with the (possibly zeroed) x̂₀ logits concatenated.
    backbone_outputs = _run_backbone(x0_hat_logits)

    return {self.prediction_type: backbone_outputs}
305+
306+
193307
################################################################################
194308
# MARK: Multi-modal Diffusion Network
195309
################################################################################

hackable_diffusion/lib/diffusion_network_test.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Tests for diffusion_network and its components."""
1616

17+
from collections.abc import Mapping
18+
1719
import chex
1820
from flax import linen as nn
1921
from hackable_diffusion.lib import diffusion_network
@@ -375,5 +377,205 @@ def test_multimodal_diffusion_network(self, input_type: str):
375377
chex.assert_trees_all_equal_structs(modified_t, output)
376378

377379

380+
################################################################################
381+
# MARK: SelfConditioningDiffusionNetwork Tests
382+
################################################################################
383+
384+
385+
class SelfConditioningBackbone(arch_typing.ConditionalBackbone):
  """Backbone for self-conditioning tests.

  Accepts input of shape (B, ..., input_channels + num_classes) and returns
  output of shape (B, ..., num_classes). The backbone simply applies a dense
  layer so the output depends on the input content.
  """

  num_classes: int = 4

  @nn.compact
  def __call__(
      self,
      x: arch_typing.DataTree,
      conditioning_embeddings: Mapping[
          arch_typing.ConditioningMechanism, Float['batch ...']
      ],
      is_training: bool,
  ) -> arch_typing.DataTree:
    # Conditioning embeddings and training mode are deliberately ignored; the
    # tests only need an output that varies with the input content.
    projection = nn.Dense(features=self.num_classes)
    return projection(x)
405+
406+
407+
class SelfConditioningDiffusionNetworkTest(parameterized.TestCase):
  """Tests for SelfConditioningDiffusionNetwork."""

  def setUp(self):
    super().setUp()
    # Shared fixture: a tiny (2, 8, 8, 1) all-ones input with one label.
    self.key = jax.random.PRNGKey(0)
    self.batch_size = 2
    self.input_channels = 1
    self.num_output_classes = 4
    self.spatial_shape = (8, 8)
    self.xt_shape = (
        self.batch_size, *self.spatial_shape, self.input_channels
    )
    self.t = jnp.ones((self.batch_size,))
    self.xt = jnp.ones(self.xt_shape)
    self.conditioning = {
        'label1': jnp.arange(self.batch_size),
    }

    self.time_encoder = conditioning_encoder.SinusoidalTimeEmbedder(
        activation='silu',
        embedding_dim=16,
        num_features=32,
    )
    self.cond_encoder = conditioning_encoder.ConditioningEncoder(
        time_embedder=self.time_encoder,
        conditioning_embedders={
            'label': conditioning_encoder.LabelEmbedder(
                conditioning_key='label1',
                num_classes=10,
                num_features=16,
            ),
        },
        embedding_merging_method=arch_typing.EmbeddingMergeMethod.CONCAT,
        conditioning_rules={
            'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,
            'label': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,
        },
    )
    self.backbone = SelfConditioningBackbone(
        num_classes=self.num_output_classes
    )

  def _make_network(
      self, self_cond_prob: float = 0.5
  ) -> diffusion_network.SelfConditioningDiffusionNetwork:
    """Builds the network under test with the given self-conditioning prob."""
    return diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
        data_dtype=jnp.float32,
        num_output_classes=self.num_output_classes,
        self_cond_prob=self_cond_prob,
    )

  def _init_variables(self, network):
    """Initializes `network` parameters on the shared fixture inputs."""
    return network.init(
        {'params': self.key, 'dropout': self.key},
        self.t, self.xt, self.conditioning, True,
    )

  def _apply(self, network, variables, is_training=True, dropout_key=None):
    """Applies `network` to the shared fixture inputs."""
    rngs = None if dropout_key is None else {'dropout': dropout_key}
    return network.apply(
        variables,
        self.t, self.xt, self.conditioning, is_training,
        rngs=rngs,
    )

  def test_output_shape(self):
    network = self._make_network()
    variables = self._init_variables(network)

    output = self._apply(network, variables, dropout_key=self.key)

    self.assertIsInstance(output, dict)
    self.assertIn('logits', output)

    expected_shape = (
        self.batch_size, *self.spatial_shape, self.num_output_classes
    )
    self.assertEqual(output['logits'].shape, expected_shape)

  def test_self_cond_prob_zero_skips_self_cond(self):
    network = self._make_network(self_cond_prob=0.0)
    variables = self._init_variables(network)

    output_no_sc = self._apply(network, variables, dropout_key=self.key)

    # With prob=1.0, self-conditioning is always applied (different output).
    # Both networks have identical parameter structure, so variables from the
    # prob=0.0 network are reusable here.
    network_always = self._make_network(self_cond_prob=1.0)
    output_always = self._apply(
        network_always, variables, dropout_key=self.key
    )
    self.assertFalse(
        jnp.allclose(output_no_sc['logits'], output_always['logits']),
        msg='Outputs should differ since self-conditioning changes the input.',
    )

  def test_self_cond_prob_one_always_self_conditions(self):
    network = self._make_network(self_cond_prob=1.0)
    variables = self._init_variables(network)

    # Run twice with different RNG — should give the same result since
    # self_cond_prob=1.0 means the random draw has no effect.
    output_a = self._apply(
        network, variables, dropout_key=jax.random.PRNGKey(42)
    )
    output_b = self._apply(
        network, variables, dropout_key=jax.random.PRNGKey(99)
    )

    chex.assert_trees_all_close(output_a, output_b)

  def test_inference_always_self_conditions(self):
    # Even with self_cond_prob=0.0, inference should self-condition.
    network = self._make_network(self_cond_prob=0.0)
    variables = self._init_variables(network)

    # Inference output (is_training=False).
    output_infer = self._apply(network, variables, is_training=False)

    # Training with self_cond_prob=1.0 should match inference.
    network_always = self._make_network(self_cond_prob=1.0)
    output_train_sc = self._apply(
        network_always, variables, dropout_key=self.key
    )

    chex.assert_trees_all_close(output_infer, output_train_sc)

  def test_default_self_cond_prob(self):
    network = diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
        num_output_classes=self.num_output_classes,
    )

    self.assertEqual(network.self_cond_prob, 0.5)

  def test_invalid_num_output_classes_raises(self):
    # num_output_classes keeps its -1 default, so init must fail.
    network = diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
    )
    with self.assertRaisesRegex(ValueError, 'num_output_classes'):
      network.init(
          {'params': self.key, 'dropout': self.key},
          self.t,
          self.xt,
          self.conditioning,
          True,
      )
578+
579+
378580
if __name__ == '__main__':
379581
absltest.main()

0 commit comments

Comments
 (0)