We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 70c6235 commit b0af372Copy full SHA for b0af372
bayesflow/__init__.py
@@ -34,6 +34,17 @@ def setup():
34
35
logging.debug(f"Using backend {keras.backend.backend()!r}")
36
37
+ if keras.backend.backend() == "torch":
38
+ import torch
39
+
40
+ torch.autograd.set_grad_enabled(False)
41
42
+ logging.warning(
43
+ "When using torch backend, we need to disable autograd by default to avoid excessive memory usage. Use\n"
44
+ "with torch.enable_grad():\n"
45
+ "in contexts where you need gradients (e.g. custom training loops)."
46
+ )
47
48
49
# call and clean up namespace
50
setup()
0 commit comments