Skip to content

Commit 97516f3

Browse files
committed
remove the is_symbolic_tensor check because this would otherwise skip the whole function for compiled contexts
1 parent 3b1c053 commit 97516f3

File tree

2 files changed

+0
-8
lines changed

2 files changed

+0
-8
lines changed

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import keras
22

33
from .. import logging
4-
from ..tensor_utils import is_symbolic_tensor
54

65
from .euclidean import euclidean
76

@@ -27,9 +26,6 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8,
2726

2827
log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
2928

30-
if is_symbolic_tensor(log_plan):
31-
return log_plan
32-
3329
def contains_nans(plan):
3430
return keras.ops.any(keras.ops.isnan(plan))
3531

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from bayesflow.types import Tensor
44

55
from .. import logging
6-
from ..tensor_utils import is_symbolic_tensor
76

87
from .euclidean import euclidean
98

@@ -76,9 +75,6 @@ def sinkhorn_plan(
7675
# initialize the transport plan from a gaussian kernel
7776
plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16))
7877

79-
if is_symbolic_tensor(plan):
80-
return plan
81-
8278
def contains_nans(plan):
8379
return keras.ops.any(keras.ops.isnan(plan))
8480

0 commit comments

Comments
 (0)