|
| 1 | +# Copyright 2026 Hackable Diffusion Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Tests for Riemannian Flow Matching corruption process.""" |
| 16 | + |
| 17 | +from absl.testing import absltest |
| 18 | +from hackable_diffusion.lib import manifolds |
| 19 | +from hackable_diffusion.lib.corruption import riemannian |
| 20 | +from hackable_diffusion.lib.corruption import schedules |
| 21 | +import jax |
| 22 | +import jax.numpy as jnp |
| 23 | +import numpy as np |
| 24 | + |
| 25 | + |
| 26 | +def _make_process(manifold): |
| 27 | + return riemannian.RiemannianProcess( |
| 28 | + manifold=manifold, |
| 29 | + schedule=schedules.LinearRiemannianSchedule(), |
| 30 | + ) |
| 31 | + |
| 32 | + |
| 33 | +class SphereCorruptionTest(absltest.TestCase): |
| 34 | + |
| 35 | + def test_corrupt(self): |
| 36 | + manifold = manifolds.Sphere() |
| 37 | + process = _make_process(manifold) |
| 38 | + key = jax.random.PRNGKey(0) |
| 39 | + |
| 40 | + batch_size = 8 |
| 41 | + x0 = manifold.random_uniform(key, (batch_size, 3)) |
| 42 | + time = jnp.linspace(0, 1, batch_size) |
| 43 | + |
| 44 | + xt, target_info = process.corrupt(key, x0, time) |
| 45 | + |
| 46 | + # xt should be on the sphere. |
| 47 | + norms = jnp.linalg.norm(xt, axis=-1) |
| 48 | + np.testing.assert_allclose(norms, 1.0, atol=1e-5) |
| 49 | + |
| 50 | + # Velocity should be tangent to the sphere at xt, i.e. <xt, vel> = 0. |
| 51 | + vel = target_info['velocity'] |
| 52 | + self.assertEqual(vel.shape, (batch_size, 3)) |
| 53 | + inner_products = jnp.sum(xt * vel, axis=-1) |
| 54 | + np.testing.assert_allclose(inner_products, 0.0, atol=1e-5) |
| 55 | + |
| 56 | + def test_velocity_at_t1(self): |
| 57 | + """At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1).""" |
| 58 | + manifold = manifolds.Sphere() |
| 59 | + process = _make_process(manifold) |
| 60 | + key = jax.random.PRNGKey(0) |
| 61 | + |
| 62 | + x0 = jnp.array([[1.0, 0.0, 0.0]]) |
| 63 | + t1 = jnp.array([1.0]) |
| 64 | + xt1, target1 = process.corrupt(key, x0, t1) |
| 65 | + np.testing.assert_allclose(xt1, x0, atol=1e-5) |
| 66 | + |
| 67 | + v1 = target1['velocity'] |
| 68 | + x1_sampled = target1['x1'] |
| 69 | + v_log = manifold.log(x0, x1_sampled) |
| 70 | + np.testing.assert_allclose(v1, -v_log, atol=1e-5) |
| 71 | + |
| 72 | + |
| 73 | +class SO3CorruptionTest(absltest.TestCase): |
| 74 | + |
| 75 | + def test_corrupt(self): |
| 76 | + manifold = manifolds.SO3() |
| 77 | + process = _make_process(manifold) |
| 78 | + key = jax.random.PRNGKey(1) |
| 79 | + |
| 80 | + batch_size = 8 |
| 81 | + x0 = manifold.random_uniform(key, (batch_size, 3, 3)) |
| 82 | + time = jnp.linspace(0, 1, batch_size) |
| 83 | + |
| 84 | + xt, target_info = process.corrupt(key, x0, time) |
| 85 | + |
| 86 | + # xt should be a valid rotation: R^T R = I and det(R) = 1. |
| 87 | + rtrt = jnp.matmul(jnp.swapaxes(xt, -2, -1), xt) |
| 88 | + eyes = jnp.broadcast_to(jnp.eye(3), rtrt.shape) |
| 89 | + np.testing.assert_allclose(rtrt, eyes, atol=1e-5) |
| 90 | + np.testing.assert_allclose(jnp.linalg.det(xt), 1.0, atol=1e-5) |
| 91 | + |
| 92 | + # Velocity should be in the tangent space: x^T v is skew-symmetric. |
| 93 | + vel = target_info['velocity'] |
| 94 | + self.assertEqual(vel.shape, (batch_size, 3, 3)) |
| 95 | + |
| 96 | + def test_velocity_at_t1(self): |
| 97 | + """At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1).""" |
| 98 | + manifold = manifolds.SO3() |
| 99 | + process = _make_process(manifold) |
| 100 | + key = jax.random.PRNGKey(1) |
| 101 | + |
| 102 | + x0 = jnp.eye(3)[None, ...] # (1, 3, 3) |
| 103 | + t1 = jnp.array([1.0]) |
| 104 | + xt1, target1 = process.corrupt(key, x0, t1) |
| 105 | + np.testing.assert_allclose(xt1, x0, atol=1e-5) |
| 106 | + |
| 107 | + v1 = target1['velocity'] |
| 108 | + x1_sampled = target1['x1'] |
| 109 | + v_log = manifold.log(x0, x1_sampled) |
| 110 | + np.testing.assert_allclose(v1, -v_log, atol=1e-4) |
| 111 | + |
| 112 | + |
| 113 | +class TorusCorruptionTest(absltest.TestCase): |
| 114 | + |
| 115 | + def test_corrupt(self): |
| 116 | + manifold = manifolds.Torus() |
| 117 | + process = _make_process(manifold) |
| 118 | + key = jax.random.PRNGKey(2) |
| 119 | + |
| 120 | + batch_size = 8 |
| 121 | + dim = 4 |
| 122 | + x0 = manifold.random_uniform(key, (batch_size, dim)) |
| 123 | + time = jnp.linspace(0, 1, batch_size) |
| 124 | + |
| 125 | + xt, target_info = process.corrupt(key, x0, time) |
| 126 | + |
| 127 | + # xt should be in [0, 1). |
| 128 | + self.assertTrue(jnp.all(xt >= 0.0)) |
| 129 | + self.assertTrue(jnp.all(xt < 1.0)) |
| 130 | + |
| 131 | + vel = target_info['velocity'] |
| 132 | + self.assertEqual(vel.shape, (batch_size, dim)) |
| 133 | + |
| 134 | + def test_velocity_at_t1(self): |
| 135 | + """At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1).""" |
| 136 | + manifold = manifolds.Torus() |
| 137 | + process = _make_process(manifold) |
| 138 | + key = jax.random.PRNGKey(2) |
| 139 | + |
| 140 | + x0 = jnp.array([[0.1, 0.5, 0.9]]) |
| 141 | + t1 = jnp.array([1.0]) |
| 142 | + xt1, target1 = process.corrupt(key, x0, t1) |
| 143 | + np.testing.assert_allclose(xt1, x0, atol=1e-5) |
| 144 | + |
| 145 | + v1 = target1['velocity'] |
| 146 | + x1_sampled = target1['x1'] |
| 147 | + v_log = manifold.log(x0, x1_sampled) |
| 148 | + np.testing.assert_allclose(v1, -v_log, atol=1e-5) |
| 149 | + |
| 150 | + |
| 151 | +if __name__ == '__main__': |
| 152 | + absltest.main() |
0 commit comments