diff --git a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py
index 6ce27fdf4..0f787d44f 100644
--- a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py
+++ b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py
@@ -105,7 +105,6 @@ def __init__(
         )

         embedding_kwargs = embedding_kwargs or {}
-        self.embedding_kwargs = embedding_kwargs
         self.time_emb = FourierEmbedding(**embedding_kwargs)
         self.time_emb_dim = self.time_emb.embed_dim

@@ -123,13 +122,14 @@ def get_config(self):
         config = {
             "subnet": self.subnet,
             "sigma": self.sigma,
-            "embedding_kwargs": self.embedding_kwargs,
+            "time_emb": self.time_emb,
             "concatenate_subnet_input": self._concatenate_subnet_input,
         }

         return base_config | serialize(config)

-    def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs):
+    @staticmethod
+    def _discretize_time(num_steps: int, rho: float = 3.5):
         t = keras.ops.linspace(0.0, pi / 2, num_steps)
         times = keras.ops.exp((t - pi / 2) * rho) * pi / 2
         times = keras.ops.concatenate([keras.ops.zeros((1,)), times[1:]], axis=0)
@@ -307,7 +307,7 @@ def compute_metrics(
         r = 1.0  # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here

         def f_teacher(x, t):
-            o = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training")
+            o = self._apply_subnet(x, self.time_emb(t), conditions, training=stage == "training")
             return self.subnet_projector(o)

         primals = (xt / self.sigma, t)
@@ -321,7 +321,7 @@ def f_teacher(x, t):
         cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)

         # calculate output of the network
-        subnet_out = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training")
+        subnet_out = self._apply_subnet(xt / self.sigma, self.time_emb(t), conditions, training=stage == "training")
         student_out = self.subnet_projector(subnet_out)

         # calculate the tangent
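For context on the `_discretize_time` change (now a `@staticmethod` with the unused `self` and `**kwargs` dropped), here is a minimal standalone sketch of the schedule it computes. The function body mirrors the diff verbatim; the standalone name `discretize_time` and the explicit imports are assumptions for illustration, not part of the PR:

```python
from math import pi

import keras


def discretize_time(num_steps: int, rho: float = 3.5):
    # Evenly spaced points on [0, pi/2], then exponentially warped so that
    # steps cluster near t = 0 while t = pi/2 is kept as the final endpoint.
    t = keras.ops.linspace(0.0, pi / 2, num_steps)
    times = keras.ops.exp((t - pi / 2) * rho) * pi / 2
    # Pin the first step to exactly 0.0; keep the warped values elsewhere.
    times = keras.ops.concatenate([keras.ops.zeros((1,)), times[1:]], axis=0)
    return times
```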
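The two `_apply_subnet` changes address the same scaling issue from opposite ends: the JVP primal is already `xt / self.sigma`, so dividing again inside `f_teacher` applied the `1/sigma` scaling twice, while the student call referenced `x` (a name only defined inside `f_teacher`) instead of `xt`. A toy sketch of the teacher fix, with hypothetical numbers that are not from the PR:

```python
# Assumed toy values, for illustration only.
sigma = 2.0
xt = 4.0

scaled_primal = xt / sigma  # what the JVP passes into f_teacher as x

old_input = scaled_primal / sigma  # before the fix: xt / sigma**2 (double-scaled)
new_input = scaled_primal          # after the fix:  xt / sigma, as intended

assert old_input == xt / sigma**2
assert new_input == xt / sigma
```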