Skip to content

Commit 382afef

Browse files
committed
fix old lstnet import and name
1 parent 2722d93 commit 382afef

File tree

6 files changed

+3
-3
lines changed

6 files changed

+3
-3
lines changed

bayesflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def setup():
3636
# turn off gradients by default
3737
import torch
3838

39-
torch.autograd.set_grad_enabled(False)
39+
# torch.autograd.set_grad_enabled(False)
4040

4141
logging.warning("Disabling gradients by default. Use\nwith torch.enable_grad():\nin custom training loops.")
4242

bayesflow/networks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .inference_network import InferenceNetwork
66
from .point_inference_network import PointInferenceNetwork
77
from .mlp import MLP
8-
from .lstnet import TimeSeriesNetwork
8+
from .time_series_network import TimeSeriesNetwork
99
from .summary_network import SummaryNetwork
1010
from .transformers import SetTransformer, TimeSeriesTransformer, FusionTransformer
1111

bayesflow/utils/workflow_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def find_summary_network(summary_network: SummaryNetwork | str, **kwargs) -> Sum
3636
return bayesflow.networks.FusionTransformer(**kwargs)
3737
case "time_series_transformer":
3838
return bayesflow.networks.TimeSeriesTransformer(**kwargs)
39-
case "lstnet":
39+
case "time_series_network":
4040
return bayesflow.networks.LSTNet(**kwargs)
4141
case str() as unknown_network:
4242
raise ValueError(f"Unknown summary network: '{unknown_network}'")

0 commit comments

Comments
 (0)