From 15305ebf87dad8ff17e2f4c1b917c0be465c6c83 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Tue, 23 Sep 2025 13:17:12 +0200
Subject: [PATCH] fix scm

---
 .../stable_consistency_model.py               | 88 ++++++++++++------
 .../networks/embeddings/fourier_embedding.py  | 22 ++++-
 2 files changed, 81 insertions(+), 29 deletions(-)

diff --git a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py
index f4fe37611..8dd5ecf3b 100644
--- a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py
+++ b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py
@@ -5,15 +5,7 @@
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
 
-from bayesflow.utils import (
-    logging,
-    jvp,
-    concatenate_valid,
-    find_network,
-    expand_right_as,
-    expand_right_to,
-    layer_kwargs,
-)
+from bayesflow.utils import logging, jvp, find_network, expand_right_as, expand_right_to, layer_kwargs, tensor_utils
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from bayesflow.networks import InferenceNetwork
@@ -83,6 +75,11 @@ def __init__(
             includes depth, hidden sizes, and non-linearity choices.
         embedding_kwargs : dict[str, any], optional, default=None
             Keyword arguments for the time embedding layer(s) used in the model
+        concatenate_subnet_input: bool, optional
+            Flag for advanced users to control whether all inputs to the subnet should be concatenated
+            into a single vector or passed as separate arguments. If set to False, the subnet
+            must accept three separate inputs: 'x' (noisy parameters), 't' (time embedding),
+            and optional 'conditions'. Default is True.
         **kwargs
             Additional keyword arguments passed to the parent ``InferenceNetwork`` initializer
             (e.g., ``name``, ``dtype``, or ``trainable``).
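For reviewers: with `concatenate_subnet_input=False`, the subnet must accept `x`, `t`, and `conditions` as separate arguments and expose `build`/`compute_output_shape` with the keyword signature used in `StableConsistencyModel.build` below. A minimal sketch of such a subnet follows; the class name, layer widths, activation, and fusion-by-summation are illustrative assumptions, not part of this patch:

```python
import keras


class SplitInputSubnet(keras.layers.Layer):
    """Illustrative subnet compatible with concatenate_subnet_input=False."""

    def __init__(self, width: int = 128, **kwargs):
        super().__init__(**kwargs)
        self.x_proj = keras.layers.Dense(width)
        self.t_proj = keras.layers.Dense(width)
        self.c_proj = keras.layers.Dense(width)
        self.hidden = keras.layers.Dense(width, activation="mish")

    def build(self, x_shape, t_shape, conditions_shape=None):
        # Mirrors the explicit call in StableConsistencyModel.build:
        # self.subnet.build(x_shape=..., t_shape=..., conditions_shape=...)
        self.x_proj.build(x_shape)
        self.t_proj.build(t_shape)
        if conditions_shape is not None:
            self.c_proj.build(conditions_shape)
        self.hidden.build(tuple(x_shape[:-1]) + (self.x_proj.units,))

    def compute_output_shape(self, x_shape, t_shape=None, conditions_shape=None):
        return tuple(x_shape[:-1]) + (self.hidden.units,)

    def call(self, x, t, conditions=None, training=False):
        # Fuse the separate inputs by summation instead of concatenation.
        h = self.x_proj(x) + self.t_proj(t)
        if conditions is not None:
            h = h + self.c_proj(conditions)
        return self.hidden(h)
```

Any layer with this interface should then be usable as the `subnet` argument of `StableConsistencyModel` together with `concatenate_subnet_input=False`.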
@@ -97,6 +94,7 @@ def __init__(
         self.subnet_projector = keras.layers.Dense(
             units=None, bias_initializer="zeros", kernel_initializer="zeros", name="subnet_projector"
         )
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
 
         weight_mlp_kwargs = weight_mlp_kwargs or {}
         weight_mlp_kwargs = StableConsistencyModel.WEIGHT_MLP_DEFAULT_CONFIG | weight_mlp_kwargs
@@ -107,6 +105,7 @@ def __init__(
         )
 
         embedding_kwargs = embedding_kwargs or {}
+        self.embedding_kwargs = embedding_kwargs
         self.time_emb = FourierEmbedding(**embedding_kwargs)
         self.time_emb_dim = self.time_emb.embed_dim
 
@@ -124,6 +123,8 @@ def get_config(self):
         config = {
             "subnet": self.subnet,
             "sigma": self.sigma,
+            "embedding_kwargs": self.embedding_kwargs,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
         }
 
         return base_config | serialize(config)
@@ -151,17 +152,22 @@ def build(self, xz_shape, conditions_shape=None):
         # construct input shape for subnet and subnet projector
         input_shape = list(xz_shape)
 
-        # time vector
-        input_shape[-1] += self.time_emb_dim + 1
-
-        if conditions_shape is not None:
-            input_shape[-1] += conditions_shape[-1]
-
-        input_shape = tuple(input_shape)
-
-        self.subnet.build(input_shape)
-
-        input_shape = self.subnet.compute_output_shape(input_shape)
+        if self._concatenate_subnet_input:
+            # construct time vector
+            input_shape[-1] += self.time_emb_dim + 1
+            if conditions_shape is not None:
+                input_shape[-1] += conditions_shape[-1]
+            input_shape = tuple(input_shape)
+
+            self.subnet.build(input_shape)
+            input_shape = self.subnet.compute_output_shape(input_shape)
+        else:
+            # Multiple separate inputs
+            time_shape = tuple(xz_shape[:-1]) + (self.time_emb_dim + 1,)  # same batch/sequence dims; time_emb_dim + 1 features
+            self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
+            input_shape = self.subnet.compute_output_shape(
+                x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
+            )
 
         self.subnet_projector.build(input_shape)
 
         # input shape for time embedding
@@ -173,6 +179,35 @@ def build(self, xz_shape, conditions_shape=None):
         input_shape = self.weight_fn.compute_output_shape(input_shape)
         self.weight_fn_projector.build(input_shape)
 
+    def _apply_subnet(
+        self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
+    ) -> Tensor:
+        """
+        Prepares and passes the input to the subnet either by concatenating the latent variable `x`,
+        the time embedding `t`, and optional conditions or by passing them as separate arguments.
+
+        Parameters
+        ----------
+        x : Tensor
+            The parameter tensor, typically of shape (..., D), but can vary.
+        t : Tensor
+            The time embedding tensor, typically of shape (..., time_emb_dim + 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g., observables or summary statistics).
+        training : bool, optional
+            The training mode flag, which can be used to control behavior during training.
+
+        Returns
+        -------
+        Tensor
+            The output tensor from the subnet.
+        """
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
+            return self.subnet(xtc, training=training)
+        else:
+            return self.subnet(x=x, t=t, conditions=conditions, training=training)
+
     def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         # Consistency Models only learn the direction from noise distribution
         # to target distribution, so we cannot implement this function.
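For reviewers unfamiliar with `concatenate_valid`: on the default path, `_apply_subnet` concatenates `x`, the time embedding, and any conditions along the last axis, skipping entries that are `None`. A standalone sketch of that behavior; the helper name and shapes are made up for illustration, and `concatenate_valid` is assumed to simply drop `None` entries before concatenating:

```python
import numpy as np
from keras import ops


def concatenate_valid_sketch(tensors, axis=-1):
    # Drop None entries, then concatenate the rest along the given axis.
    return ops.concatenate([t for t in tensors if t is not None], axis=axis)


x = np.zeros((4, 2), dtype="float32")      # noisy parameters, scaled by 1/sigma
t_emb = np.zeros((4, 9), dtype="float32")  # time embedding (embed_dim + 1 with identity)
conditions = None                          # the unconditional case is allowed

xtc = concatenate_valid_sketch([x, t_emb, conditions], axis=-1)
print(ops.shape(xtc))  # (4, 11): 2 + 9 features; conditions contributed nothing
```

This is why `build` adds `self.time_emb_dim + 1` (and, if present, `conditions_shape[-1]`) to the last axis of the subnet input shape on the concatenation path.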
@@ -218,7 +253,6 @@ def consistency_function(
         t: Tensor,
         conditions: Tensor = None,
         training: bool = False,
-        **kwargs,
     ) -> Tensor:
         """Compute consistency function at time t.
 
@@ -235,8 +269,6 @@ def consistency_function(
-        **kwargs : dict, optional, default: {}
-            Additional keyword arguments passed to the inner network.
         """
-        xtc = concatenate_valid([x / self.sigma, self.time_emb(t), conditions], axis=-1)
-        f = self.subnet_projector(self.subnet(xtc, training=training, **kwargs))
+        subnet_out = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=training)
+        f = self.subnet_projector(subnet_out)
 
         out = ops.cos(t) * x - ops.sin(t) * self.sigma * f
         return out
@@ -273,7 +305,7 @@ def compute_metrics(
         r = 1.0  # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here
 
         def f_teacher(x, t):
-            o = self.subnet(concatenate_valid([x, self.time_emb(t), conditions], axis=-1), training=stage == "training")
+            o = self._apply_subnet(x, self.time_emb(t), conditions, training=stage == "training")
             return self.subnet_projector(o)
 
         primals = (xt / self.sigma, t)
@@ -287,8 +319,8 @@ def f_teacher(x, t):
         cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)
 
         # calculate output of the network
-        xtc = concatenate_valid([xt / self.sigma, self.time_emb(t), conditions], axis=-1)
-        student_out = self.subnet_projector(self.subnet(xtc, training=stage == "training"))
+        subnet_out = self._apply_subnet(xt / self.sigma, self.time_emb(t), conditions, training=stage == "training")
+        student_out = self.subnet_projector(subnet_out)
 
         # calculate the tangent
         g = -(ops.cos(t) ** 2) * (self.sigma * teacher_output - dxtdt) - r * ops.cos(t) * ops.sin(t) * (
diff --git a/bayesflow/networks/embeddings/fourier_embedding.py b/bayesflow/networks/embeddings/fourier_embedding.py
index 21924ee60..dc30e654e 100644
--- a/bayesflow/networks/embeddings/fourier_embedding.py
+++ b/bayesflow/networks/embeddings/fourier_embedding.py
@@ -4,7 +4,8 @@
 from keras import ops
 
 from bayesflow.types import Tensor
-from bayesflow.utils.serialization import serializable
+from bayesflow.utils import layer_kwargs
+from bayesflow.utils.serialization import serializable, serialize, deserialize
 
 
 @serializable("bayesflow.networks")
@@ -47,6 +48,8 @@ def __init__(
         self.scale = scale
         self.embed_dim = embed_dim
         self.include_identity = include_identity
+        self.initializer = initializer
+        self.trainable = trainable
 
     def call(self, t: Tensor) -> Tensor:
         """Embeds the one-dimensional time scalar into a higher-dimensional Fourier embedding.
@@ -68,3 +71,20 @@ def call(self, t: Tensor) -> Tensor:
         else:
             emb = ops.concatenate([ops.sin(proj), ops.cos(proj)], axis=-1)
         return emb
+
+    def get_config(self):
+        base_config = super().get_config()
+        base_config = layer_kwargs(base_config)
+
+        config = {
+            "embed_dim": self.embed_dim,
+            "scale": self.scale,
+            "initializer": self.initializer,
+            "trainable": self.trainable,
+            "include_identity": self.include_identity,
+        }
+        return base_config | serialize(config)
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))
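For reviewers: the new `get_config`/`from_config` pair lets `FourierEmbedding` survive a serialization round trip, which the `embedding_kwargs` entry added to `StableConsistencyModel.get_config` depends on. A quick sanity check; the constructor arguments are inferred from the attributes set in `__init__`, and the import path is assumed from the file location in this patch:

```python
from bayesflow.networks.embeddings.fourier_embedding import FourierEmbedding

# Build an embedding, serialize it to a config, and restore it.
emb = FourierEmbedding(embed_dim=8, scale=16.0, include_identity=True)
restored = FourierEmbedding.from_config(emb.get_config())

# The restored layer should reproduce the constructor arguments.
assert restored.embed_dim == emb.embed_dim
assert restored.scale == emb.scale
assert restored.include_identity == emb.include_identity
```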