Skip to content

Commit b8ff02c

Browse files
vdebortoHackable Diffusion Authors
authored andcommitted
RFM: Riemannian Flow Matching Process
PiperOrigin-RevId: 884061391
1 parent ab2913b commit b8ff02c

File tree

9 files changed

+670
-95
lines changed

9 files changed

+670
-95
lines changed

hackable_diffusion/lib/corruption/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from hackable_diffusion.lib.corruption.discrete import PostCorruptionFn
2323
from hackable_diffusion.lib.corruption.discrete import SymmetricPostCorruptionFn
2424
from hackable_diffusion.lib.corruption.gaussian import GaussianProcess
25+
from hackable_diffusion.lib.corruption.riemannian import RiemannianProcess
2526
from hackable_diffusion.lib.corruption.schedules import CosineDiscreteSchedule
2627
from hackable_diffusion.lib.corruption.schedules import CosineSchedule
2728
from hackable_diffusion.lib.corruption.schedules import DiscreteSchedule
@@ -31,8 +32,10 @@
3132
from hackable_diffusion.lib.corruption.schedules import InverseCosineSchedule
3233
from hackable_diffusion.lib.corruption.schedules import LinearDiffusionSchedule
3334
from hackable_diffusion.lib.corruption.schedules import LinearDiscreteSchedule
35+
from hackable_diffusion.lib.corruption.schedules import LinearRiemannianSchedule
3436
from hackable_diffusion.lib.corruption.schedules import PolynomialDiscreteSchedule
3537
from hackable_diffusion.lib.corruption.schedules import RFSchedule
38+
from hackable_diffusion.lib.corruption.schedules import RiemannianSchedule
3639
from hackable_diffusion.lib.corruption.schedules import Schedule
3740
from hackable_diffusion.lib.corruption.schedules import ShiftedSchedule
3841
from hackable_diffusion.lib.corruption.schedules import SquareCosineDiscreteSchedule
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2026 Hackable Diffusion Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Riemannian Flow Matching corruption process."""
16+
17+
import dataclasses
18+
from typing import Any
19+
20+
from hackable_diffusion.lib import hd_typing
21+
from hackable_diffusion.lib import manifolds
22+
from hackable_diffusion.lib import utils
23+
from hackable_diffusion.lib.corruption import base
24+
from hackable_diffusion.lib.corruption import schedules
25+
import kauldron.ktyping as kt
26+
27+
################################################################################
28+
# MARK: Type Aliases
29+
################################################################################
30+
31+
PRNGKey = hd_typing.PRNGKey
32+
DataArray = hd_typing.DataArray
33+
TimeArray = hd_typing.TimeArray
34+
TargetInfo = hd_typing.TargetInfo
35+
36+
################################################################################
37+
# MARK: Riemannian Flow Matching corruption process
38+
################################################################################
39+
40+
41+
@dataclasses.dataclass(kw_only=True, frozen=True)
42+
class RiemannianProcess(base.CorruptionProcess):
43+
"""Riemannian Flow Matching corruption process.
44+
45+
This is based on https://arxiv.org/abs/2302.03660.
46+
47+
Given a schedule with interpolation parameter alpha(t):
48+
x_t = geodesic(x_0, x_1, alpha(t))
49+
target = alpha'(t) * velocity(x_0, x_1, alpha(t))
50+
"""
51+
52+
manifold: manifolds.Manifold
53+
schedule: schedules.RiemannianSchedule
54+
55+
@kt.typechecked
56+
def sample_from_invariant(
57+
self,
58+
key: PRNGKey,
59+
data_spec: DataArray,
60+
) -> DataArray:
61+
"""Sample from the base distribution (uniform) on the manifold."""
62+
return self.manifold.random_uniform(key, data_spec.shape)
63+
64+
@kt.typechecked
65+
def corrupt(
66+
self,
67+
key: PRNGKey,
68+
x0: DataArray,
69+
time: TimeArray,
70+
) -> tuple[DataArray, TargetInfo]:
71+
x1 = self.sample_from_invariant(key, data_spec=x0)
72+
73+
# Evaluate schedule: alpha(t) is the geodesic interpolation parameter.
74+
alpha_t = utils.bcast_right(self.schedule.alpha(time), x0.ndim)
75+
alpha_dot_t = utils.bcast_right(self.schedule.alpha_dot(time), x0.ndim)
76+
77+
# x_t = geodesic(x0, x1, alpha(t)).
78+
xt = manifolds.geodesic(self.manifold, x1, x0, alpha_t)
79+
80+
# By chain rule: d/dt x_t = alpha'(t) * velocity(x0, x1, alpha(t)).
81+
velocity = alpha_dot_t * self.manifold.velocity(x1, x0, alpha_t)
82+
83+
target_info = {
84+
'x0': x0,
85+
'x1': x1,
86+
'velocity': velocity,
87+
}
88+
89+
return xt, target_info
90+
91+
@kt.typechecked
92+
def convert_predictions(
93+
self,
94+
prediction: TargetInfo,
95+
xt: DataArray,
96+
time: TimeArray,
97+
) -> TargetInfo:
98+
"""Convert predictions to velocity parameterization."""
99+
if 'velocity' in prediction:
100+
return prediction
101+
raise NotImplementedError(
102+
'Only velocity prediction is supported for RFM currently.'
103+
)
104+
105+
@kt.typechecked
106+
def get_schedule_info(self, time: TimeArray) -> dict[str, Any]:
107+
return self.schedule.evaluate(time)
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright 2026 Hackable Diffusion Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for Riemannian Flow Matching corruption process."""
16+
17+
from absl.testing import absltest
18+
from hackable_diffusion.lib import manifolds
19+
from hackable_diffusion.lib.corruption import riemannian
20+
from hackable_diffusion.lib.corruption import schedules
21+
import jax
22+
import jax.numpy as jnp
23+
import numpy as np
24+
25+
26+
def _make_process(manifold):
27+
return riemannian.RiemannianProcess(
28+
manifold=manifold,
29+
schedule=schedules.LinearRiemannianSchedule(),
30+
)
31+
32+
33+
class SphereCorruptionTest(absltest.TestCase):
34+
35+
def test_corrupt(self):
36+
manifold = manifolds.Sphere()
37+
process = _make_process(manifold)
38+
key = jax.random.PRNGKey(0)
39+
40+
batch_size = 8
41+
x0 = manifold.random_uniform(key, (batch_size, 3))
42+
time = jnp.linspace(0, 1, batch_size)
43+
44+
xt, target_info = process.corrupt(key, x0, time)
45+
46+
# xt should be on the sphere.
47+
norms = jnp.linalg.norm(xt, axis=-1)
48+
np.testing.assert_allclose(norms, 1.0, atol=1e-5)
49+
50+
# Velocity should be tangent to the sphere at xt, i.e. <xt, vel> = 0.
51+
vel = target_info['velocity']
52+
self.assertEqual(vel.shape, (batch_size, 3))
53+
inner_products = jnp.sum(xt * vel, axis=-1)
54+
np.testing.assert_allclose(inner_products, 0.0, atol=1e-5)
55+
56+
def test_velocity_at_t1(self):
57+
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
58+
manifold = manifolds.Sphere()
59+
process = _make_process(manifold)
60+
key = jax.random.PRNGKey(0)
61+
62+
x0 = jnp.array([[1.0, 0.0, 0.0]])
63+
t1 = jnp.array([1.0])
64+
xt1, target1 = process.corrupt(key, x0, t1)
65+
np.testing.assert_allclose(xt1, x0, atol=1e-5)
66+
67+
v1 = target1['velocity']
68+
x1_sampled = target1['x1']
69+
v_log = manifold.log(x0, x1_sampled)
70+
np.testing.assert_allclose(v1, -v_log, atol=1e-5)
71+
72+
73+
class SO3CorruptionTest(absltest.TestCase):
74+
75+
def test_corrupt(self):
76+
manifold = manifolds.SO3()
77+
process = _make_process(manifold)
78+
key = jax.random.PRNGKey(1)
79+
80+
batch_size = 8
81+
x0 = manifold.random_uniform(key, (batch_size, 3, 3))
82+
time = jnp.linspace(0, 1, batch_size)
83+
84+
xt, target_info = process.corrupt(key, x0, time)
85+
86+
# xt should be a valid rotation: R^T R = I and det(R) = 1.
87+
rtrt = jnp.matmul(jnp.swapaxes(xt, -2, -1), xt)
88+
eyes = jnp.broadcast_to(jnp.eye(3), rtrt.shape)
89+
np.testing.assert_allclose(rtrt, eyes, atol=1e-5)
90+
np.testing.assert_allclose(jnp.linalg.det(xt), 1.0, atol=1e-5)
91+
92+
# Velocity should be in the tangent space: x^T v is skew-symmetric.
93+
vel = target_info['velocity']
94+
self.assertEqual(vel.shape, (batch_size, 3, 3))
95+
96+
def test_velocity_at_t1(self):
97+
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
98+
manifold = manifolds.SO3()
99+
process = _make_process(manifold)
100+
key = jax.random.PRNGKey(1)
101+
102+
x0 = jnp.eye(3)[None, ...] # (1, 3, 3)
103+
t1 = jnp.array([1.0])
104+
xt1, target1 = process.corrupt(key, x0, t1)
105+
np.testing.assert_allclose(xt1, x0, atol=1e-5)
106+
107+
v1 = target1['velocity']
108+
x1_sampled = target1['x1']
109+
v_log = manifold.log(x0, x1_sampled)
110+
np.testing.assert_allclose(v1, -v_log, atol=1e-4)
111+
112+
113+
class TorusCorruptionTest(absltest.TestCase):
114+
115+
def test_corrupt(self):
116+
manifold = manifolds.Torus()
117+
process = _make_process(manifold)
118+
key = jax.random.PRNGKey(2)
119+
120+
batch_size = 8
121+
dim = 4
122+
x0 = manifold.random_uniform(key, (batch_size, dim))
123+
time = jnp.linspace(0, 1, batch_size)
124+
125+
xt, target_info = process.corrupt(key, x0, time)
126+
127+
# xt should be in [0, 1).
128+
self.assertTrue(jnp.all(xt >= 0.0))
129+
self.assertTrue(jnp.all(xt < 1.0))
130+
131+
vel = target_info['velocity']
132+
self.assertEqual(vel.shape, (batch_size, dim))
133+
134+
def test_velocity_at_t1(self):
135+
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
136+
manifold = manifolds.Torus()
137+
process = _make_process(manifold)
138+
key = jax.random.PRNGKey(2)
139+
140+
x0 = jnp.array([[0.1, 0.5, 0.9]])
141+
t1 = jnp.array([1.0])
142+
xt1, target1 = process.corrupt(key, x0, t1)
143+
np.testing.assert_allclose(xt1, x0, atol=1e-5)
144+
145+
v1 = target1['velocity']
146+
x1_sampled = target1['x1']
147+
v_log = manifold.log(x0, x1_sampled)
148+
np.testing.assert_allclose(v1, -v_log, atol=1e-5)
149+
150+
151+
if __name__ == '__main__':
152+
absltest.main()

hackable_diffusion/lib/corruption/schedules.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,59 @@ def evaluate(self, time: TimeArray) -> dict[str, TimeArray]:
109109
SimplicialSchedule = DiscreteSchedule
110110

111111

112+
################################################################################
113+
# MARK: Riemannian Schedules
114+
################################################################################
115+
116+
117+
class RiemannianSchedule(abc.ABC, Schedule):
118+
"""Base class for Riemannian schedules.
119+
120+
Controls the geodesic interpolation via alpha(t):
121+
x_t = geodesic(x_0, x_1, alpha(t))
122+
v_t = alpha'(t) * velocity(x_0, x_1, alpha(t))
123+
124+
Subclasses must implement `alpha`.
125+
"""
126+
127+
@abc.abstractmethod
128+
def alpha(self, time: TimeArray) -> TimeArray:
129+
"""The geodesic interpolation parameter at time t."""
130+
131+
def alpha_dot(self, time: TimeArray) -> TimeArray:
132+
"""Time derivative of alpha. Defaults to autodiff."""
133+
return utils.egrad(self.alpha)(time)
134+
135+
@kt.typechecked
136+
def evaluate(self, time: TimeArray) -> dict[str, TimeArray]:
137+
return {
138+
'time': time,
139+
'alpha': self.alpha(time),
140+
'alpha_dot': self.alpha_dot(time),
141+
}
142+
143+
144+
class LinearRiemannianSchedule(RiemannianSchedule):
145+
"""Linear Riemannian schedule: alpha(t) = 1.0 - t.
146+
147+
This is the standard flow matching schedule where the geodesic interpolation
148+
parameter equals time directly.
149+
Note that contrary to the original Riemannian Flow Matching, we assume that at
150+
time t=0, the process is close to the data distribution, and at time t=1,
151+
the process is close to the target distribution.
152+
Hence, we use alpha(t) = 1.0 - t, and alpha_dot(t) = -1.0m instead of
153+
alpha(t) = t, and alpha_dot(t) = 1.0.
154+
"""
155+
156+
@kt.typechecked
157+
def alpha(self, time: TimeArray) -> TimeArray:
158+
return 1.0 - time
159+
160+
@kt.typechecked
161+
def alpha_dot(self, time: TimeArray) -> TimeArray:
162+
return -jnp.ones_like(time)
163+
164+
112165
################################################################################
113166
# MARK: Gaussian Schedules
114167
################################################################################

0 commit comments

Comments
 (0)