@@ -19,6 +19,7 @@ def sinkhorn(
1919 max_steps : int | None = 10_000 ,
2020 tolerance : float = 1e-6 ,
2121 numpy : bool = False ,
22+ scale_regularization : bool = True ,
2223) -> (Tensor , Tensor ):
2324 """
2425 Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm.
@@ -57,6 +58,11 @@ def sinkhorn(
5758 :param tolerance: Absolute tolerance for convergence.
5859 Default: 1e-6
5960
61+ :param scale_regularization: Whether to scale the regularization parameter with the half-mean of the cost matrix.
62+ This makes the value of the regularization parameter robust to the dimensionality and typical range of the
63+ samples.
64+ Default: True
65+
6066 :return: Tensors of shapes (n, ...) and (m, ...)
6167 x1 and x2 in optimal transport permutation order.
6268 """
@@ -69,6 +75,7 @@ def sinkhorn(
6975 max_steps = max_steps ,
7076 tolerance = tolerance ,
7177 numpy = numpy ,
78+ scale_regularization = scale_regularization ,
7279 )
7380
7481 if numpy :
@@ -91,6 +98,7 @@ def sinkhorn_indices(
9198 max_steps : int | None = 10_000 ,
9299 tolerance : float = 1e-6 ,
93100 numpy : bool = False ,
101+ scale_regularization : bool = True ,
94102) -> Tensor | np .ndarray :
95103 """
96104 Samples a set of optimal transport permutation indices using the Sinkhorn-Knopp algorithm.
@@ -118,6 +126,11 @@ def sinkhorn_indices(
118126
119127 :param numpy: Whether to use numpy or keras backend.
120128
129+ :param scale_regularization: Whether to scale the regularization parameter with the half-mean of the cost matrix.
130+ This makes the value of the regularization parameter robust to the dimensionality and typical range of the
131+ samples.
132+ Default: True
133+
121134 :return: Tensor of shape (n,)
122135 Randomly sampled optimal permutation indices for the first distribution.
123136 """
@@ -129,6 +142,7 @@ def sinkhorn_indices(
129142 max_steps = max_steps ,
130143 tolerance = tolerance ,
131144 numpy = numpy ,
145+ scale_regularization = scale_regularization ,
132146 )
133147
134148 if numpy :
@@ -148,7 +162,14 @@ def sinkhorn_indices(
148162
149163
150164def sinkhorn_plan (
151- x1 : Tensor , x2 : Tensor , cost : Tensor , regularization : float , max_steps : int , tolerance : float , numpy : bool = False
165+ x1 : Tensor ,
166+ x2 : Tensor ,
167+ cost : str | Tensor ,
168+ regularization : float ,
169+ max_steps : int ,
170+ tolerance : float ,
171+ numpy : bool = False ,
172+ scale_regularization : bool = True ,
152173) -> Tensor :
153174 """
154175 Computes the Sinkhorn-Knopp optimal transport plan.
@@ -172,19 +193,39 @@ def sinkhorn_plan(
172193 :param numpy: Whether to use numpy or keras backend.
173194 Default: False
174195
196+ :param scale_regularization: Whether to scale the regularization parameter with the half-mean of the cost matrix.
197+ This makes the value of the regularization parameter robust to the dimensionality and typical range of the
198+ samples.
199+ Default: True
200+
175201 :return: Tensor of shape (n, m)
176202 The transport probabilities.
177203 """
178204 cost = find_cost (cost , x1 , x2 , numpy = numpy )
179205
180206 if numpy :
181- return sinkhorn_plan_numpy (cost = cost , regularization = regularization , max_steps = max_steps , tolerance = tolerance )
182- return sinkhorn_plan_keras (cost = cost , regularization = regularization , max_steps = max_steps , tolerance = tolerance )
207+ return sinkhorn_plan_numpy (
208+ cost = cost ,
209+ regularization = regularization ,
210+ max_steps = max_steps ,
211+ tolerance = tolerance ,
212+ scale_regularization = scale_regularization ,
213+ )
214+ return sinkhorn_plan_keras (
215+ cost = cost ,
216+ regularization = regularization ,
217+ max_steps = max_steps ,
218+ tolerance = tolerance ,
219+ scale_regularization = scale_regularization ,
220+ )
183221
184222
185- def sinkhorn_plan_keras (cost : Tensor , regularization : float , max_steps : int , tolerance : float ) -> Tensor :
186- # scale regularization with the half-mean of the cost
187- regularization = 0.5 * regularization * keras .ops .mean (cost )
223+ def sinkhorn_plan_keras (
224+ cost : Tensor , regularization : float , max_steps : int , tolerance : float , scale_regularization : bool
225+ ) -> Tensor :
226+ if scale_regularization :
227+ # scale regularization with the half-mean of the cost
228+ regularization = 0.5 * regularization * keras .ops .mean (cost )
188229
189230 # initialize the transport plan from a gaussian kernel
190231 plan = keras .ops .exp (- 0.5 * cost / regularization )
@@ -227,14 +268,21 @@ def warn():
227268 return plan
228269
229270
230- def sinkhorn_plan_numpy (cost : np .ndarray , regularization : float , max_steps : int , tolerance : float ) -> np .ndarray :
231- # scale regularization with the half-mean of the cost
232- regularization = 0.5 * regularization * np .mean (cost )
271+ def sinkhorn_plan_numpy (
272+ cost : np .ndarray , regularization : float , max_steps : int , tolerance : float , scale_regularization : bool
273+ ) -> np .ndarray :
274+ if scale_regularization :
275+ # scale regularization with the half-mean of the cost
276+ regularization = 0.5 * regularization * np .mean (cost )
233277
234278 # initialize the transport plan from a gaussian kernel
235279 plan = np .exp (- 0.5 * cost / regularization )
236280
237- for _ in range (max_steps ):
281+ step = 0
282+ while True :
283+ if step >= max_steps :
284+ break
285+
238286 # check convergence: the plan should be doubly stochastic
239287 marginals = np .sum (plan , axis = 0 ), np .sum (plan , axis = 1 )
240288 deviations = np .abs (marginals [0 ] - 1.0 ), np .abs (marginals [1 ] - 1.0 )
@@ -246,6 +294,8 @@ def sinkhorn_plan_numpy(cost: np.ndarray, regularization: float, max_steps: int,
246294 plan = softmax (plan , axis = 0 )
247295 plan = softmax (plan , axis = 1 )
248296
297+ step += 1
298+
249299 marginals = np .sum (plan , axis = 0 )
250300 deviations = np .abs (marginals - 1.0 )
251301
0 commit comments