Skip to content

Commit 0ed2524

Browse files
committed
Fix torch autograd bug
1 parent cfdfed6 commit 0ed2524

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

bayesflow/diagnostics/metrics/classifier_two_sample_test.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,6 @@ def classifier_two_sample_test(
7171
full training history.
7272
"""
7373

74-
# Convert tensors to numpy, if passed
75-
estimates = keras.ops.convert_to_numpy(estimates)
76-
targets = keras.ops.convert_to_numpy(targets)
77-
7874
# Error, if targets dim does not match estimates dim
7975
num_dims = estimates.shape[1]
8076
if not num_dims == targets.shape[1]:
@@ -111,15 +107,30 @@ def classifier_two_sample_test(
111107
monitor=f"val_{metric}", patience=patience, restore_best_weights=True
112108
)
113109

114-
history = classifier.fit(
115-
x=data,
116-
y=labels,
117-
epochs=max_epochs,
118-
batch_size=batch_size,
119-
verbose=0,
120-
callbacks=[early_stopping],
121-
validation_split=validation_split,
122-
)
110+
# For now, we need to enable grads, since we turn them off by default
111+
if keras.backend.backend() == "torch":
112+
import torch
113+
114+
with torch.enable_grad():
115+
history = classifier.fit(
116+
x=data,
117+
y=labels,
118+
epochs=max_epochs,
119+
batch_size=batch_size,
120+
verbose=0,
121+
callbacks=[early_stopping],
122+
validation_split=validation_split,
123+
)
124+
else:
125+
history = classifier.fit(
126+
x=data,
127+
y=labels,
128+
epochs=max_epochs,
129+
batch_size=batch_size,
130+
verbose=0,
131+
callbacks=[early_stopping],
132+
validation_split=validation_split,
133+
)
123134

124135
if return_metric_only:
125136
return history.history[f"val_{metric}"][-1]

0 commit comments

Comments
 (0)