Skip to content

Commit b0af372

Browse files
committed
revert 35385df
also improve warning
1 parent 70c6235 commit b0af372

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

bayesflow/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ def setup():
3434

3535
logging.debug(f"Using backend {keras.backend.backend()!r}")
3636

37+
if keras.backend.backend() == "torch":
38+
import torch
39+
40+
torch.autograd.set_grad_enabled(False)
41+
42+
logging.warning(
43+
"When using torch backend, we need to disable autograd by default to avoid excessive memory usage. Use\n"
44+
"with torch.enable_grad():\n"
45+
"in contexts where you need gradients (e.g. custom training loops)."
46+
)
47+
3748

3849
# call and clean up namespace
3950
setup()

0 commit comments

Comments
 (0)