Skip to content

Commit 17551f5

Browse files
committed
Reflect name change SequentialNetwork -> SequenceNetwork
1 parent 0f53c8a commit 17551f5

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

bayesflow/summary_networks.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,8 @@ def __init__(self, *args, **kwargs):
410410
super().__init__(*args, **kwargs)
411411

412412

413-
class SequentialNetwork(tf.keras.Model):
414-
"""Implements a sequence of `MultiConv1D` layers followed by an LSTM network.
413+
class SequenceNetwork(tf.keras.Model):
414+
"""Implements a sequence of `MultiConv1D` layers followed by an (bidirectional) LSTM network.
415415
416416
For details and rationale, see [1]:
417417
@@ -484,6 +484,22 @@ def call(self, x, **kwargs):
484484
return out
485485

486486

487+
class SequentialNetwork(SequenceNetwork):
488+
"""Deprecated class for amortizer posterior estimation."""
489+
490+
def __init_subclass__(cls, **kwargs):
491+
warn(f"{cls.__name__} will be deprecated. Use `SequenceNetwork` instead.", DeprecationWarning, stacklevel=2)
492+
super().__init_subclass__(**kwargs)
493+
494+
def __init__(self, *args, **kwargs):
495+
warn(
496+
f"{self.__class__.__name__} will be deprecated. Use `SequenceNetwork` instead.",
497+
DeprecationWarning,
498+
stacklevel=2,
499+
)
500+
super().__init__(*args, **kwargs)
501+
502+
487503
class SplitNetwork(tf.keras.Model):
488504
"""Implements a vertical stack of networks and concatenates their individual outputs. Allows for splitting
489505
of data to provide an individual network for each split of the data.
@@ -565,7 +581,7 @@ def __init__(self, networks_list, **kwargs):
565581
566582
Example: For two-level hierarchical models with the assumption of temporal dependencies on the lowest
567583
hierarchical level (e.g., observational level) and exchangeable units at the higher level
568-
(e.g., group level), a list of [SequentialNetwork(), DeepSet()] could be passed.
584+
(e.g., group level), a list of [SequenceNetwork(), DeepSet()] could be passed.
569585
570586
----------
571587

0 commit comments

Comments
 (0)