Skip to content

Commit 455f03c

Browse files
committed
fix dtype
1 parent bcb9f60 commit 455f03c

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -603,8 +603,6 @@ def compositional_velocity(
603603

604604
# Get shapes for compositional structure
605605
n_compositional = ops.shape(conditions)[1]
606-
n = ops.cast(n_compositional, dtype=ops.dtype(time))
607-
time_tensor = ops.cast(time, dtype=ops.dtype(xz))
608606

609607
# Calculate standard noise schedule components
610608
log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
@@ -628,9 +626,10 @@ def compositional_velocity(
628626
summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1)
629627

630628
# Prior contribution
631-
weighted_prior_score = (1.0 - n) * (1.0 - time_tensor) * prior_score
629+
weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score
632630

633631
# Combined score
632+
time_tensor = ops.cast(time, dtype=ops.dtype(xz))
634633
compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores)
635634

636635
# Compute velocity using standard drift-diffusion formulation

0 commit comments

Comments
 (0)