In addition to #242, the summary networks could do with some general clean-up, mostly of their constructor arguments, most of which are not meaningful to non-developers. I think it would also help to simplify the implementation, e.g., by replacing the multiple poolings used in DeepSet with a single pooling.
This might make the networks slightly less configurable, but it could greatly improve their usability by lowering the entry barrier to configuration. I think most users would benefit, and power users could still implement their own version to get back any configuration options that are dropped.
This is just an idea meant for discussion. I would be glad to hear your thoughts!
Here is an example implementation for a reduced DeepSet:
import keras
from collections.abc import Sequence
from typing import Literal

# serializable, SummaryNetwork, MLP, and Tensor are BayesFlow internals (imports omitted here)


@serializable(package="bayesflow.networks")
class DeepSet(SummaryNetwork):
    def __init__(
        self,
        *,
        summary_dim: int = 16,
        widths: tuple[Sequence[int], Sequence[int]] = ((128, 128), (128, 128)),
        pooling: Literal["sum", "mean"] = "mean",
        activation: str = "gelu",
        dropout: float | None = 0.05,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.summary_dim = summary_dim
        self.pooling = pooling

        # widths[0] configures the MLP applied to each set member
        self.equivariant_mlp = MLP(
            widths=widths[0],
            activation=activation,
            dropout=dropout,
        )
        self.pooling_layer = PoolingLayer(pooling)
        # widths[1] configures the MLP applied to the pooled representation
        self.invariant_mlp = MLP(
            widths=widths[1],
            activation=activation,
            dropout=dropout,
        )
        self.output_projector = keras.layers.Dense(summary_dim, activation=None)

    @property
    def pooling(self):
        return self.pooling_layer.method

    @pooling.setter
    def pooling(self, pooling: Literal["sum", "mean"]):
        self.pooling_layer = PoolingLayer(pooling)

    def build(self, input_shape):
        super().build(input_shape)
        # run a dummy forward pass so that all sublayers get built
        self.call(keras.ops.zeros(input_shape))

    def call(self, x: Tensor, **kwargs) -> Tensor:
        x = self.equivariant_mlp(x)
        x = self.pooling_layer(x)
        x = self.invariant_mlp(x)
        x = self.output_projector(x)
        return x

which uses this PoolingLayer:
@serializable(package="bayesflow.networks")
class PoolingLayer(keras.Layer):
    def __init__(self, method: Literal["mean", "sum"] = "mean", axis: int = 1, keepdims: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.method = method
        self.axis = axis
        self.keepdims = keepdims

        # stateless layers do not need to be built
        self.built = True

    def call(self, x: Tensor) -> Tensor:
        match self.method:
            case "mean":
                x = keras.ops.mean(x, axis=self.axis, keepdims=self.keepdims)
            case "sum":
                x = keras.ops.sum(x, axis=self.axis, keepdims=self.keepdims)
            case other:
                raise ValueError(f"Unknown pooling method: {other!r}")
        return x