Skip to content

Commit 286bea6

Browse files
sinkhorn(x1, x2) now samples from log(plan) to receive assignments such that x2[assignments] matches x1
1 parent 4502a36 commit 286bea6

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
2727
:param seed: Random seed to use for sampling indices.
2828
Default: None, which means the seed will be auto-determined for non-compiled contexts.
2929
30-
:return: Tensor of shape (m,)
30+
:return: Tensor of shape (n,)
3131
Assignment indices for x2.
3232
3333
"""
34-
plan = sinkhorn_plan(x1, x2, **kwargs) # shape: (n, m)
34+
plan = sinkhorn_plan(x1, x2, **kwargs)
3535

36-
# we sample from plan.T to receive assignments of length m, with elements up to n
37-
assignments = keras.random.categorical(keras.ops.log(plan.T), num_samples=1, seed=seed)
36+
# we sample from log(plan) to receive assignments of length n, corresponding to indices of x2
37+
# such that x2[assignments] matches x1
38+
assignments = keras.random.categorical(keras.ops.log(plan), num_samples=1, seed=seed)
3839
assignments = keras.ops.squeeze(assignments, axis=1)
3940

4041
return assignments

0 commit comments

Comments
 (0)