File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed
bayesflow/networks/diffusion_model Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -603,8 +603,6 @@ def compositional_velocity(
603603
604604 # Get shapes for compositional structure
605605 n_compositional = ops .shape (conditions )[1 ]
606- n = ops .cast (n_compositional , dtype = ops .dtype (time ))
607- time_tensor = ops .cast (time , dtype = ops .dtype (xz ))
608606
609607 # Calculate standard noise schedule components
610608 log_snr_t = expand_right_as (self .noise_schedule .get_log_snr (t = time , training = training ), xz )
@@ -628,9 +626,10 @@ def compositional_velocity(
628626 summed_individual_scores = n_compositional * ops .mean (individual_scores , axis = 1 )
629627
630628 # Prior contribution
631- weighted_prior_score = (1.0 - n ) * (1.0 - time_tensor ) * prior_score
629+ weighted_prior_score = (1.0 - n_compositional ) * (1.0 - time ) * prior_score
632630
633631 # Combined score
632+ time_tensor = ops .cast (time , dtype = ops .dtype (xz ))
634633 compositional_score = self .compositional_bridge (time_tensor ) * (weighted_prior_score + summed_individual_scores )
635634
636635 # Compute velocity using standard drift-diffusion formulation
You can’t perform that action at this time.
0 commit comments