Skip to content

Commit 111271b

Browse files
vdebortoHackable Diffusion Authors
authored andcommitted
RFM: Riemannian Flow Matching Architectures
PiperOrigin-RevId: 884062709
1 parent 0805a54 commit 111271b

File tree

5 files changed

+1283
-3
lines changed

5 files changed

+1283
-3
lines changed

hackable_diffusion/lib/architecture/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from hackable_diffusion.lib.architecture.mlp_blocks import MLP
5050
from hackable_diffusion.lib.architecture.normalization import NormalizationLayer
5151
from hackable_diffusion.lib.architecture.normalization import NormalizationLayerFactory
52+
from hackable_diffusion.lib.architecture.riemannian import RiemannianConditionalBackbone
5253
from hackable_diffusion.lib.architecture.sequence_embedders import RandomFourierSequenceEmbedding
5354
from hackable_diffusion.lib.architecture.sequence_embedders import RoPESequenceEmbedding
5455
from hackable_diffusion.lib.architecture.sequence_embedders import SinusoidalSequenceEmbedding
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2026 Hackable Diffusion Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Riemannian Flow Matching architectures."""
16+
17+
import flax.linen as nn
18+
from hackable_diffusion.lib import manifolds
19+
from hackable_diffusion.lib.architecture import arch_typing
20+
21+
################################################################################
22+
# MARK: Riemannian Conditional Backbone
23+
################################################################################
24+
25+
ConditionalBackbone = arch_typing.ConditionalBackbone
26+
27+
28+
class RiemannianConditionalBackbone(ConditionalBackbone):
29+
"""Velocity model for Riemannian Flow Matching.
30+
31+
Projects the output of a backbone network to the tangent space of a manifold.
32+
"""
33+
34+
backbone: ConditionalBackbone
35+
manifold: manifolds.Manifold
36+
37+
@nn.compact
38+
def __call__(self, x, conditioning_embeddings, is_training=True):
39+
40+
v = self.backbone(x, conditioning_embeddings, is_training=is_training)
41+
42+
# Project v to tangent space at xt.
43+
if isinstance(v, dict) and 'velocity' in v:
44+
v = v['velocity']
45+
46+
v_proj = self.manifold.project(x, v)
47+
return v_proj
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2026 Hackable Diffusion Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for Riemannian Flow Matching architectures."""
16+
17+
from absl.testing import absltest
18+
from hackable_diffusion.lib import manifolds
19+
from hackable_diffusion.lib.architecture import arch_typing
20+
from hackable_diffusion.lib.architecture import mlp
21+
from hackable_diffusion.lib.architecture import riemannian
22+
import jax
23+
import jax.numpy as jnp
24+
25+
26+
class RiemannianArchitectureTest(absltest.TestCase):
27+
28+
def test_riemannian_backbone_projection(self):
29+
manifold = manifolds.Sphere()
30+
backbone = mlp.ConditionalMLP(
31+
hidden_sizes_preprocess=(16,),
32+
hidden_sizes_postprocess=(16,),
33+
activation='relu',
34+
zero_init_output=True,
35+
dropout_rate=0.0,
36+
conditioning_mechanism=mlp.ConditioningMechanism.CONCATENATE,
37+
)
38+
model = riemannian.RiemannianConditionalBackbone(
39+
backbone=backbone,
40+
manifold=manifold,
41+
)
42+
43+
key = jax.random.PRNGKey(0)
44+
xt = manifold.random_uniform(key, (4, 3))
45+
time_emb = jnp.array([[0.5], [0.5], [0.5], [0.5]])
46+
47+
# conditioning_embeddings must be a dict keyed by ConditioningMechanism.
48+
conditioning_embeddings = {
49+
arch_typing.ConditioningMechanism.CONCATENATE: time_emb,
50+
}
51+
52+
variables = model.init(key, xt, conditioning_embeddings, is_training=False)
53+
v = model.apply(variables, xt, conditioning_embeddings, is_training=False)
54+
55+
self.assertEqual(v.shape, (4, 3))
56+
57+
# Check that v is tangent to xt
58+
inner_products = jnp.sum(xt * v, axis=-1)
59+
# Project should ensure dot(xt, v) = 0 for sphere
60+
self.assertAlmostEqual(jnp.max(jnp.abs(inner_products)), 0.0, places=5)
61+
62+
63+
if __name__ == '__main__':
64+
absltest.main()

hackable_diffusion/lib/sampling/riemannian_sampling_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _make_sampler(manifold):
3838
class RiemannianFlowSamplerStepTest(absltest.TestCase):
3939

4040
def test_update_sphere(self):
41-
"""Euler step on moves along geodesic."""
41+
"""Euler step on S2 moves along geodesic."""
4242
manifold = manifolds.Sphere()
4343
sampler = _make_sampler(manifold)
4444
key = jax.random.PRNGKey(0)
@@ -69,7 +69,6 @@ def test_update_so3(self):
6969
key = jax.random.PRNGKey(1)
7070

7171
xt = jnp.eye(3)[None, ...] # Identity rotation (1, 3, 3).
72-
# Tangent vector at identity is a skew-symmetric matrix.
7372
v = jnp.array([[[0.0, -0.1, 0.0], [0.1, 0.0, 0.0], [0.0, 0.0, 0.0]]])
7473

7574
current_step = base.DiffusionStep(
@@ -109,7 +108,6 @@ def test_update_torus(self):
109108
# dt = 1.0, so next_xt = exp(xt, v) = (xt + v) % 1.0.
110109
expected_xt = jnp.array([[(0.9 + 0.5) % 1.0, (0.1 - 0.5) % 1.0, 0.5]])
111110
np.testing.assert_allclose(next_step.xt, expected_xt, atol=1e-5)
112-
# Result stays in [0, 1).
113111
self.assertTrue(jnp.all(next_step.xt >= 0.0))
114112
self.assertTrue(jnp.all(next_step.xt < 1.0))
115113

hackable_diffusion/notebooks/riemannian_sphere_training.ipynb

Lines changed: 1170 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)