Skip to content

Commit 19bf5a0

Browse files
vdebortoHackable Diffusion Authors
authored andcommitted
RFM: Riemannian Flow Matching Architectures
PiperOrigin-RevId: 884062709
1 parent 0805a54 commit 19bf5a0

File tree

7 files changed

+1336
-17
lines changed

7 files changed

+1336
-17
lines changed

hackable_diffusion/lib/architecture/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from hackable_diffusion.lib.architecture.mlp_blocks import MLP
5050
from hackable_diffusion.lib.architecture.normalization import NormalizationLayer
5151
from hackable_diffusion.lib.architecture.normalization import NormalizationLayerFactory
52+
from hackable_diffusion.lib.architecture.riemannian import RiemannianConditionalBackbone
5253
from hackable_diffusion.lib.architecture.sequence_embedders import RandomFourierSequenceEmbedding
5354
from hackable_diffusion.lib.architecture.sequence_embedders import RoPESequenceEmbedding
5455
from hackable_diffusion.lib.architecture.sequence_embedders import SinusoidalSequenceEmbedding
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2026 Hackable Diffusion Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Riemannian Flow Matching architectures."""
16+
17+
import flax.linen as nn
18+
from hackable_diffusion.lib import manifolds
19+
from hackable_diffusion.lib.architecture import arch_typing
20+
21+
################################################################################
22+
# MARK: Riemannian Conditional Backbone
23+
################################################################################
24+
25+
ConditionalBackbone = arch_typing.ConditionalBackbone
26+
27+
28+
class RiemannianConditionalBackbone(ConditionalBackbone):
29+
"""Velocity model for Riemannian Flow Matching.
30+
31+
Projects the output of a backbone network to the tangent space of a manifold.
32+
"""
33+
34+
backbone: ConditionalBackbone
35+
manifold: manifolds.Manifold
36+
37+
@nn.compact
38+
def __call__(self, x, conditioning_embeddings, is_training=True):
39+
40+
v = self.backbone(x, conditioning_embeddings, is_training=is_training)
41+
42+
# Project v to tangent space at xt.
43+
if isinstance(v, dict) and 'velocity' in v:
44+
v = v['velocity']
45+
46+
v_proj = self.manifold.project(x, v)
47+
return v_proj
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2026 Hackable Diffusion Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for Riemannian Flow Matching architectures."""
16+
17+
from absl.testing import absltest
18+
from hackable_diffusion.lib import manifolds
19+
from hackable_diffusion.lib.architecture import arch_typing
20+
from hackable_diffusion.lib.architecture import mlp
21+
from hackable_diffusion.lib.architecture import riemannian
22+
import jax
23+
import jax.numpy as jnp
24+
25+
26+
class RiemannianArchitectureTest(absltest.TestCase):
27+
28+
def test_riemannian_backbone_projection(self):
29+
manifold = manifolds.Sphere()
30+
backbone = mlp.ConditionalMLP(
31+
hidden_sizes_preprocess=(16,),
32+
hidden_sizes_postprocess=(16,),
33+
activation='relu',
34+
zero_init_output=True,
35+
dropout_rate=0.0,
36+
conditioning_mechanism=mlp.ConditioningMechanism.CONCATENATE,
37+
)
38+
model = riemannian.RiemannianConditionalBackbone(
39+
backbone=backbone,
40+
manifold=manifold,
41+
)
42+
43+
key = jax.random.PRNGKey(0)
44+
xt = manifold.random_uniform(key, (4, 3))
45+
time_emb = jnp.array([[0.5], [0.5], [0.5], [0.5]])
46+
47+
# conditioning_embeddings must be a dict keyed by ConditioningMechanism.
48+
conditioning_embeddings = {
49+
arch_typing.ConditioningMechanism.CONCATENATE: time_emb,
50+
}
51+
52+
variables = model.init(key, xt, conditioning_embeddings, is_training=False)
53+
v = model.apply(variables, xt, conditioning_embeddings, is_training=False)
54+
55+
self.assertEqual(v.shape, (4, 3))
56+
57+
# Check that v is tangent to xt
58+
inner_products = jnp.sum(xt * v, axis=-1)
59+
# Project should ensure dot(xt, v) = 0 for sphere
60+
self.assertAlmostEqual(jnp.max(jnp.abs(inner_products)), 0.0, places=5)
61+
62+
63+
if __name__ == '__main__':
64+
absltest.main()

hackable_diffusion/lib/corruption/riemannian.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def corrupt(
6767
alpha_dot_t = utils.bcast_right(self.schedule.alpha_dot(time), x0.ndim)
6868

6969
# x_t = geodesic(x0, x1, alpha(t)).
70-
xt = self.manifold.exp(x0, alpha_t * self.manifold.log(x0, x1))
70+
xt = manifolds.geodesic(self.manifold, x1, x0, alpha_t)
7171

7272
# By chain rule: d/dt x_t = alpha'(t) * velocity(x0, x1, alpha(t)).
73-
vel = alpha_dot_t * self.manifold.velocity(x0, x1, alpha_t)
73+
vel = alpha_dot_t * self.manifold.velocity(x1, x0, alpha_t)
7474

7575
target_info = {
7676
'x0': x0,

hackable_diffusion/lib/corruption/riemannian_test.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,21 @@ def test_corrupt(self):
5454
np.testing.assert_allclose(inner_products, 0.0, atol=1e-5)
5555

5656
def test_velocity_at_t1(self):
57-
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
57+
"""At t=1, alpha=0 so xt = x1 and velocity = -log(x1, x0)."""
5858
manifold = manifolds.Sphere()
5959
process = _make_process(manifold)
6060
key = jax.random.PRNGKey(0)
6161

6262
x0 = jnp.array([[1.0, 0.0, 0.0]])
6363
t1 = jnp.array([1.0])
6464
xt1, target1 = process.corrupt(key, x0, t1)
65-
np.testing.assert_allclose(xt1, x0, atol=1e-5)
65+
x1_sampled = target1['x1']
66+
# At t=1, alpha=0: geodesic(x1, x0, 0) = x1.
67+
np.testing.assert_allclose(xt1, x1_sampled, atol=1e-5)
6668

69+
# velocity = alpha_dot(1) * velocity(x1, x0, 0) = -1 * log(x1, x0).
6770
v1 = target1['velocity']
68-
x1_sampled = target1['x1']
69-
v_log = manifold.log(x0, x1_sampled)
71+
v_log = manifold.log(x1_sampled, x0)
7072
np.testing.assert_allclose(v1, -v_log, atol=1e-5)
7173

7274

@@ -94,19 +96,21 @@ def test_corrupt(self):
9496
self.assertEqual(vel.shape, (batch_size, 3, 3))
9597

9698
def test_velocity_at_t1(self):
97-
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
99+
"""At t=1, alpha=0 so xt = x1 and velocity = -log(x1, x0)."""
98100
manifold = manifolds.SO3()
99101
process = _make_process(manifold)
100102
key = jax.random.PRNGKey(1)
101103

102104
x0 = jnp.eye(3)[None, ...] # (1, 3, 3)
103105
t1 = jnp.array([1.0])
104106
xt1, target1 = process.corrupt(key, x0, t1)
105-
np.testing.assert_allclose(xt1, x0, atol=1e-5)
107+
x1_sampled = target1['x1']
108+
# At t=1, alpha=0: geodesic(x1, x0, 0) = x1.
109+
np.testing.assert_allclose(xt1, x1_sampled, atol=1e-5)
106110

111+
# velocity = alpha_dot(1) * velocity(x1, x0, 0) = -1 * log(x1, x0).
107112
v1 = target1['velocity']
108-
x1_sampled = target1['x1']
109-
v_log = manifold.log(x0, x1_sampled)
113+
v_log = manifold.log(x1_sampled, x0)
110114
np.testing.assert_allclose(v1, -v_log, atol=1e-4)
111115

112116

@@ -132,19 +136,21 @@ def test_corrupt(self):
132136
self.assertEqual(vel.shape, (batch_size, dim))
133137

134138
def test_velocity_at_t1(self):
135-
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
139+
"""At t=1, alpha=0 so xt = x1 and velocity = -log(x1, x0)."""
136140
manifold = manifolds.Torus()
137141
process = _make_process(manifold)
138142
key = jax.random.PRNGKey(2)
139143

140144
x0 = jnp.array([[0.1, 0.5, 0.9]])
141145
t1 = jnp.array([1.0])
142146
xt1, target1 = process.corrupt(key, x0, t1)
143-
np.testing.assert_allclose(xt1, x0, atol=1e-5)
147+
x1_sampled = target1['x1']
148+
# At t=1, alpha=0: geodesic(x1, x0, 0) = x1.
149+
np.testing.assert_allclose(xt1, x1_sampled, atol=1e-5)
144150

151+
# velocity = alpha_dot(1) * velocity(x1, x0, 0) = -1 * log(x1, x0).
145152
v1 = target1['velocity']
146-
x1_sampled = target1['x1']
147-
v_log = manifold.log(x0, x1_sampled)
153+
v_log = manifold.log(x1_sampled, x0)
148154
np.testing.assert_allclose(v1, -v_log, atol=1e-5)
149155

150156

hackable_diffusion/lib/sampling/riemannian_sampling_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _make_sampler(manifold):
3838
class RiemannianFlowSamplerStepTest(absltest.TestCase):
3939

4040
def test_update_sphere(self):
41-
"""Euler step on moves along geodesic."""
41+
"""Euler step on S2 moves along geodesic."""
4242
manifold = manifolds.Sphere()
4343
sampler = _make_sampler(manifold)
4444
key = jax.random.PRNGKey(0)
@@ -69,7 +69,6 @@ def test_update_so3(self):
6969
key = jax.random.PRNGKey(1)
7070

7171
xt = jnp.eye(3)[None, ...] # Identity rotation (1, 3, 3).
72-
# Tangent vector at identity is a skew-symmetric matrix.
7372
v = jnp.array([[[0.0, -0.1, 0.0], [0.1, 0.0, 0.0], [0.0, 0.0, 0.0]]])
7473

7574
current_step = base.DiffusionStep(
@@ -109,7 +108,6 @@ def test_update_torus(self):
109108
# dt = 1.0, so next_xt = exp(xt, v) = (xt + v) % 1.0.
110109
expected_xt = jnp.array([[(0.9 + 0.5) % 1.0, (0.1 - 0.5) % 1.0, 0.5]])
111110
np.testing.assert_allclose(next_step.xt, expected_xt, atol=1e-5)
112-
# Result stays in [0, 1).
113111
self.assertTrue(jnp.all(next_step.xt >= 0.0))
114112
self.assertTrue(jnp.all(next_step.xt < 1.0))
115113

0 commit comments

Comments
 (0)