|
14 | 14 |
|
15 | 15 | """Tests for diffusion_network and its components.""" |
16 | 16 |
|
| 17 | +from collections.abc import Mapping |
| 18 | + |
17 | 19 | import chex |
18 | 20 | from flax import linen as nn |
19 | 21 | from hackable_diffusion.lib import diffusion_network |
@@ -375,5 +377,205 @@ def test_multimodal_diffusion_network(self, input_type: str): |
375 | 377 | chex.assert_trees_all_equal_structs(modified_t, output) |
376 | 378 |
|
377 | 379 |
|
| 380 | +################################################################################ |
| 381 | +# MARK: SelfConditioningDiffusionNetwork Tests |
| 382 | +################################################################################ |
| 383 | + |
| 384 | + |
class SelfConditioningBackbone(arch_typing.ConditionalBackbone):
  """Minimal dense backbone used by the self-conditioning tests.

  Expects input of shape (B, ..., input_channels + num_classes) and produces
  output of shape (B, ..., num_classes). A single dense projection is applied
  so that the output actually depends on the input content (rather than being
  a constant stub).
  """

  num_classes: int = 4

  @nn.compact
  def __call__(
      self,
      x: arch_typing.DataTree,
      conditioning_embeddings: Mapping[
          arch_typing.ConditioningMechanism, Float['batch ...']
      ],
      is_training: bool,
  ) -> arch_typing.DataTree:
    # The conditioning embeddings and the training flag are intentionally
    # ignored: the tests only need an input-dependent output.
    projection = nn.Dense(features=self.num_classes)
    return projection(x)
| 405 | + |
| 406 | + |
class SelfConditioningDiffusionNetworkTest(parameterized.TestCase):
  """Tests for diffusion_network.SelfConditioningDiffusionNetwork."""

  def setUp(self):
    """Builds a small shared configuration: RNG key, inputs, encoders, backbone."""
    super().setUp()
    self.key = jax.random.PRNGKey(0)
    self.batch_size = 2
    self.input_channels = 1
    self.num_output_classes = 4
    self.spatial_shape = (8, 8)
    # Noisy-sample shape: (batch, *spatial, channels).
    self.xt_shape = (
        self.batch_size, *self.spatial_shape, self.input_channels
    )
    self.t = jnp.ones((self.batch_size,))
    self.xt = jnp.ones(self.xt_shape)
    # Conditioning dict keyed by the LabelEmbedder's `conditioning_key` below.
    self.conditioning = {
        'label1': jnp.arange(self.batch_size),
    }

    self.time_encoder = conditioning_encoder.SinusoidalTimeEmbedder(
        activation='silu',
        embedding_dim=16,
        num_features=32,
    )
    self.cond_encoder = conditioning_encoder.ConditioningEncoder(
        time_embedder=self.time_encoder,
        conditioning_embedders={
            'label': conditioning_encoder.LabelEmbedder(
                conditioning_key='label1',
                num_classes=10,
                num_features=16,
            ),
        },
        embedding_merging_method=arch_typing.EmbeddingMergeMethod.CONCAT,
        conditioning_rules={
            'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,
            'label': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,
        },
    )
    self.backbone = SelfConditioningBackbone(
        num_classes=self.num_output_classes
    )

  def _make_network(
      self, self_cond_prob: float = 0.5
  ) -> diffusion_network.SelfConditioningDiffusionNetwork:
    """Returns a network under test with the given self-conditioning probability."""
    return diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
        data_dtype=jnp.float32,
        num_output_classes=self.num_output_classes,
        self_cond_prob=self_cond_prob,
    )

  def test_output_shape(self):
    """Output is a dict with a 'logits' entry of shape (B, *spatial, classes)."""
    network = self._make_network()
    variables = network.init(
        {'params': self.key, 'dropout': self.key},
        self.t, self.xt, self.conditioning, True,
    )

    output = network.apply(
        variables,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': self.key},
    )

    self.assertIsInstance(output, dict)
    self.assertIn('logits', output)

    expected_shape = (
        self.batch_size, *self.spatial_shape, self.num_output_classes
    )

    self.assertEqual(output['logits'].shape, expected_shape)

  def test_self_cond_prob_zero_skips_self_cond(self):
    """prob=0.0 and prob=1.0 produce different outputs with shared params."""
    network = self._make_network(self_cond_prob=0.0)
    variables = network.init(
        {'params': self.key, 'dropout': self.key},
        self.t, self.xt, self.conditioning, True,
    )

    output_no_sc = network.apply(
        variables,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': self.key},
    )
    # With prob=1.0, self-conditioning is always applied (different output).
    # NOTE: the same `variables` are reused, so any difference comes purely
    # from the self-conditioning branch, not from different parameters.
    network_always = self._make_network(self_cond_prob=1.0)
    output_always = network_always.apply(
        variables,
        self.t,
        self.xt,
        self.conditioning,
        True,
        rngs={'dropout': self.key},
    )
    self.assertFalse(
        jnp.allclose(output_no_sc['logits'], output_always['logits']),
        msg='Outputs should differ since self-conditioning changes the input.',
    )

  def test_self_cond_prob_one_always_self_conditions(self):
    """With prob=1.0 the dropout RNG draw cannot affect the output."""
    network = self._make_network(self_cond_prob=1.0)
    variables = network.init(
        {'params': self.key, 'dropout': self.key},
        self.t, self.xt, self.conditioning, True,
    )
    # Run twice with different RNG — should give the same result since
    # self_cond_prob=1.0 means the random draw has no effect.
    output_a = network.apply(
        variables,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': jax.random.PRNGKey(42)},
    )
    output_b = network.apply(
        variables,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': jax.random.PRNGKey(99)},
    )

    chex.assert_trees_all_close(output_a, output_b)

  def test_inference_always_self_conditions(self):
    """Inference (is_training=False) self-conditions regardless of the prob."""
    # Even with self_cond_prob=0.0, inference should self-condition.
    network = self._make_network(self_cond_prob=0.0)
    variables = network.init(
        {'params': self.key, 'dropout': self.key},
        self.t, self.xt, self.conditioning, True,
    )
    # Inference output (is_training=False).
    output_infer = network.apply(
        variables,
        self.t, self.xt, self.conditioning, False,
    )
    # Training with self_cond_prob=1.0 should match inference.
    network_always = self._make_network(self_cond_prob=1.0)

    output_train_sc = network_always.apply(
        variables,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': self.key},
    )

    chex.assert_trees_all_close(output_infer, output_train_sc)

  def test_default_self_cond_prob(self):
    """Omitting self_cond_prob defaults it to 0.5."""
    network = diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
        num_output_classes=self.num_output_classes,
    )

    self.assertEqual(network.self_cond_prob, 0.5)

  def test_invalid_num_output_classes_raises(self):
    """init raises a ValueError mentioning num_output_classes when it is unset."""
    network = diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
    )
    with self.assertRaisesRegex(ValueError, 'num_output_classes'):
      network.init(
          {'params': self.key, 'dropout': self.key},
          self.t,
          self.xt,
          self.conditioning,
          True,
      )
| 578 | + |
| 579 | + |
# Allow running this test module directly via `python <file>`.
if __name__ == '__main__':
  absltest.main()
0 commit comments