Skip to content

Commit 94c1688

Browse files
committed
Add `scale_regularization` parameter to sinkhorn; fix `max_steps=None` handling in the numpy sinkhorn backend
1 parent 91e84ba commit 94c1688

File tree

1 file changed: 60 additions and 10 deletions.

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def sinkhorn(
1919
max_steps: int | None = 10_000,
2020
tolerance: float = 1e-6,
2121
numpy: bool = False,
22+
scale_regularization: bool = True,
2223
) -> (Tensor, Tensor):
2324
"""
2425
Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm.
@@ -57,6 +58,11 @@ def sinkhorn(
5758
:param tolerance: Absolute tolerance for convergence.
5859
Default: 1e-6
5960
61+
:param scale_regularization: Whether to scale the regularization parameter with the half-mean of the cost matrix.
62+
This makes the value of the regularization parameter robust to the dimensionality and typical range of the
63+
samples.
64+
Default: True
65+
6066
:return: Tensors of shapes (n, ...) and (m, ...)
6167
x1 and x2 in optimal transport permutation order.
6268
"""
@@ -69,6 +75,7 @@ def sinkhorn(
6975
max_steps=max_steps,
7076
tolerance=tolerance,
7177
numpy=numpy,
78+
scale_regularization=scale_regularization,
7279
)
7380

7481
if numpy:
@@ -91,6 +98,7 @@ def sinkhorn_indices(
9198
max_steps: int | None = 10_000,
9299
tolerance: float = 1e-6,
93100
numpy: bool = False,
101+
scale_regularization: bool = True,
94102
) -> Tensor | np.ndarray:
95103
"""
96104
Samples a set of optimal transport permutation indices using the Sinkhorn-Knopp algorithm.
@@ -118,6 +126,11 @@ def sinkhorn_indices(
118126
119127
:param numpy: Whether to use numpy or keras backend.
120128
129+
:param scale_regularization: Whether to scale the regularization parameter with the half-mean of the cost matrix.
130+
This makes the value of the regularization parameter robust to the dimensionality and typical range of the
131+
samples.
132+
Default: True
133+
121134
:return: Tensor of shape (n,)
122135
Randomly sampled optimal permutation indices for the first distribution.
123136
"""
@@ -129,6 +142,7 @@ def sinkhorn_indices(
129142
max_steps=max_steps,
130143
tolerance=tolerance,
131144
numpy=numpy,
145+
scale_regularization=scale_regularization,
132146
)
133147

134148
if numpy:
@@ -148,7 +162,14 @@ def sinkhorn_indices(
148162

149163

150164
def sinkhorn_plan(
151-
x1: Tensor, x2: Tensor, cost: Tensor, regularization: float, max_steps: int, tolerance: float, numpy: bool = False
165+
x1: Tensor,
166+
x2: Tensor,
167+
cost: str | Tensor,
168+
regularization: float,
169+
max_steps: int,
170+
tolerance: float,
171+
numpy: bool = False,
172+
scale_regularization: bool = True,
152173
) -> Tensor:
153174
"""
154175
Computes the Sinkhorn-Knopp optimal transport plan.
@@ -172,19 +193,39 @@ def sinkhorn_plan(
172193
:param numpy: Whether to use numpy or keras backend.
173194
Default: False
174195
196+
:param scale_regularization: Whether to scale the regularization parameter with the half-mean of the cost matrix.
197+
This makes the value of the regularization parameter robust to the dimensionality and typical range of the
198+
samples.
199+
Default: True
200+
175201
:return: Tensor of shape (n, m)
176202
The transport probabilities.
177203
"""
178204
cost = find_cost(cost, x1, x2, numpy=numpy)
179205

180206
if numpy:
181-
return sinkhorn_plan_numpy(cost=cost, regularization=regularization, max_steps=max_steps, tolerance=tolerance)
182-
return sinkhorn_plan_keras(cost=cost, regularization=regularization, max_steps=max_steps, tolerance=tolerance)
207+
return sinkhorn_plan_numpy(
208+
cost=cost,
209+
regularization=regularization,
210+
max_steps=max_steps,
211+
tolerance=tolerance,
212+
scale_regularization=scale_regularization,
213+
)
214+
return sinkhorn_plan_keras(
215+
cost=cost,
216+
regularization=regularization,
217+
max_steps=max_steps,
218+
tolerance=tolerance,
219+
scale_regularization=scale_regularization,
220+
)
183221

184222

185-
def sinkhorn_plan_keras(cost: Tensor, regularization: float, max_steps: int, tolerance: float) -> Tensor:
186-
# scale regularization with the half-mean of the cost
187-
regularization = 0.5 * regularization * keras.ops.mean(cost)
223+
def sinkhorn_plan_keras(
224+
cost: Tensor, regularization: float, max_steps: int, tolerance: float, scale_regularization: bool
225+
) -> Tensor:
226+
if scale_regularization:
227+
# scale regularization with the half-mean of the cost
228+
regularization = 0.5 * regularization * keras.ops.mean(cost)
188229

189230
# initialize the transport plan from a gaussian kernel
190231
plan = keras.ops.exp(-0.5 * cost / regularization)
@@ -227,14 +268,21 @@ def warn():
227268
return plan
228269

229270

230-
def sinkhorn_plan_numpy(cost: np.ndarray, regularization: float, max_steps: int, tolerance: float) -> np.ndarray:
231-
# scale regularization with the half-mean of the cost
232-
regularization = 0.5 * regularization * np.mean(cost)
271+
def sinkhorn_plan_numpy(
272+
cost: np.ndarray, regularization: float, max_steps: int, tolerance: float, scale_regularization: bool
273+
) -> np.ndarray:
274+
if scale_regularization:
275+
# scale regularization with the half-mean of the cost
276+
regularization = 0.5 * regularization * np.mean(cost)
233277

234278
# initialize the transport plan from a gaussian kernel
235279
plan = np.exp(-0.5 * cost / regularization)
236280

237-
for _ in range(max_steps):
281+
step = 0
282+
while True:
283+
if step >= max_steps:
284+
break
285+
238286
# check convergence: the plan should be doubly stochastic
239287
marginals = np.sum(plan, axis=0), np.sum(plan, axis=1)
240288
deviations = np.abs(marginals[0] - 1.0), np.abs(marginals[1] - 1.0)
@@ -246,6 +294,8 @@ def sinkhorn_plan_numpy(cost: np.ndarray, regularization: float, max_steps: int,
246294
plan = softmax(plan, axis=0)
247295
plan = softmax(plan, axis=1)
248296

297+
step += 1
298+
249299
marginals = np.sum(plan, axis=0)
250300
deviations = np.abs(marginals - 1.0)
251301

0 commit comments

Comments
 (0)