Skip to content

Commit dc4ee7b

Browse files
committed
fix function name
1 parent 9d9a73c commit dc4ee7b

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def convert_prediction_to_x(
200200
return (z + sigma_t**2 * pred) / alpha_t
201201
raise ValueError(f"Unknown prediction type {self._prediction_type}.")
202202

203-
def _subnet_input(
203+
def _apply_subnet(
204204
self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None, training: bool = False
205205
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
206206
"""
@@ -270,7 +270,7 @@ def velocity(
270270
log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
271271
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
272272

273-
subnet_out = self._subnet_input(
273+
subnet_out = self._apply_subnet(
274274
xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
275275
)
276276
pred = self.output_projector(subnet_out, training=training)
@@ -491,7 +491,7 @@ def compute_metrics(
491491
diffused_x = alpha_t * x + sigma_t * eps_t
492492

493493
# calculate output of the network
494-
subnet_out = self._subnet_input(
494+
subnet_out = self._apply_subnet(
495495
diffused_x, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
496496
)
497497
pred = self.output_projector(subnet_out, training=training)

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
258258
x = self.consistency_function(x_n, t, conditions=conditions, training=training)
259259
return x
260260

261-
def _subnet_input(
261+
def _apply_subnet(
262262
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
263263
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
264264
"""
@@ -302,7 +302,7 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
302302
Whether internal layers (e.g., dropout) should behave in train or inference mode.
303303
"""
304304

305-
subnet_out = self._subnet_input(x, t, conditions, training=training)
305+
subnet_out = self._apply_subnet(x, t, conditions, training=training)
306306
f = self.output_projector(subnet_out)
307307

308308
# Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def get_config(self):
155155

156156
return base_config | serialize(config)
157157

158-
def _subnet_input(
158+
def _apply_subnet(
159159
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
160160
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
161161
"""
@@ -191,7 +191,7 @@ def velocity(self, xz: Tensor, time: float | Tensor, conditions: Tensor = None,
191191
time = keras.ops.convert_to_tensor(time, dtype=keras.ops.dtype(xz))
192192
time = expand_right_as(time, xz)
193193

194-
subnet_out = self._subnet_input(xz, time, conditions, training=training)
194+
subnet_out = self._apply_subnet(xz, time, conditions, training=training)
195195
return self.output_projector(subnet_out, training=training)
196196

197197
def _velocity_trace(

0 commit comments

Comments
 (0)