Commit 2a9b0e1

minor fixes

1 parent eac9aaf · commit 2a9b0e1

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 9 additions & 32 deletions
@@ -721,7 +721,6 @@ def compositional_score(
         # Calculate standard noise schedule components
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
         log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
-        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

         # Compute individual dataset scores
         if mini_batch_size is not None and mini_batch_size < n_compositional:
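Context for this removal: alpha_t and sigma_t are pure functions of log_snr_t, so a score function that receives log_snr_t can derive them internally and they no longer need to be threaded through `_compute_individual_scores`. A minimal sketch, assuming a variance-preserving parameterization (the actual noise schedule in the repository may use a different one):

```python
# Minimal sketch (not part of the commit): why dropping alpha_t/sigma_t loses no information.
# Assumes a variance-preserving schedule, where log SNR = log(alpha^2 / sigma^2)
# and alpha^2 + sigma^2 = 1.
import numpy as np

def alpha_sigma_from_log_snr(log_snr_t):
    alpha_t = np.sqrt(1.0 / (1.0 + np.exp(-log_snr_t)))  # sqrt(sigmoid(log_snr))
    sigma_t = np.sqrt(1.0 / (1.0 + np.exp(log_snr_t)))   # sqrt(sigmoid(-log_snr))
    return alpha_t, sigma_t

log_snr_t = np.array([[-2.0], [0.0], [3.0]])
alpha_t, sigma_t = alpha_sigma_from_log_snr(log_snr_t)
# The removed hand-rolled score further down used exactly these quantities:
#   score = (alpha_t * x_pred - z) / sigma_t**2
print(np.allclose(alpha_t**2 + sigma_t**2, 1.0))  # True
```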
@@ -731,13 +730,13 @@ def compositional_score(
             conditions_batch = conditions[:, mini_batch_idx]
         else:
             conditions_batch = conditions
-        individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training)
+        individual_scores = self._compute_individual_scores(xz, log_snr_t, conditions_batch, training)

         # Compute prior score component
         prior_score = compute_prior_score(xz)
         weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score

-        # Sum individual scores across compositional dimensiont
+        # Sum individual scores across compositional dimensions
         summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1)

         # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ)
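The two terms computed in this hunk plausibly combine as in the formula in the comment above. A minimal NumPy sketch of that combination (not part of the commit; the function name, shapes, and the final addition are illustrative assumptions based on that comment):

```python
# Sketch of how weighted_prior_score and summed_individual_scores combine
# into the compositional score: (1-n)(1-t) * grad log p(theta) + sum_i s_psi(theta, t, y_i).
import numpy as np

def compositional_score_sketch(prior_score, individual_scores, n_compositional, time):
    # (1 - n)(1 - t) * grad log p(theta): prior correction term
    weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score
    # n * mean over the compositional axis equals the sum over all datasets,
    # or an estimate of it when only a mini-batch of datasets was scored
    summed_individual_scores = n_compositional * np.mean(individual_scores, axis=1)
    return weighted_prior_score + summed_individual_scores

prior = np.random.randn(4, 32, 2)      # (n_datasets, num_samples, dims)
indiv = np.random.randn(4, 10, 32, 2)  # (n_datasets, n_compositional, num_samples, dims)
combined = compositional_score_sketch(prior, indiv, n_compositional=10, time=0.5)
print(combined.shape)                  # (4, 32, 2)
```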
@@ -749,8 +748,6 @@ def _compute_individual_scores(
         self,
         xz: Tensor,
         log_snr_t: Tensor,
-        alpha_t: Tensor,
-        sigma_t: Tensor,
         conditions: Tensor,
         training: bool,
     ) -> Tensor:
@@ -762,9 +759,6 @@ def _compute_individual_scores(
         Tensor
             Individual scores with shape (n_datasets, n_compositional, ...)
         """
-        # Apply subnet to each compositional condition separately
-        transformed_log_snr = self._transform_log_snr(log_snr_t)
-
         # Get shapes
         xz_shape = ops.shape(xz)  # (n_datasets, num_samples, ..., dims)
         conditions_shape = ops.shape(conditions)  # (n_datasets, n_compositional, num_samples, ..., dims)
@@ -777,38 +771,21 @@ def _compute_individual_scores(
         xz_expanded = ops.expand_dims(xz, axis=1)  # (n_datasets, 1, num_samples, ..., dims)
         xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims)

-        # Expand noise schedule components to match compositional structure
-        log_snr_expanded = ops.expand_dims(transformed_log_snr, axis=1)
+        # Expand log_snr_t to match compositional structure
+        log_snr_expanded = ops.expand_dims(log_snr_t, axis=1)
         log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1))

-        alpha_expanded = ops.expand_dims(alpha_t, axis=1)
-        alpha_expanded = ops.broadcast_to(alpha_expanded, (n_datasets, n_compositional, num_samples, 1))
-
-        sigma_expanded = ops.expand_dims(sigma_t, axis=1)
-        sigma_expanded = ops.broadcast_to(sigma_expanded, (n_datasets, n_compositional, num_samples, 1))
-
-        # Flatten for subnet application: (n_datasets * n_compositional, num_samples, ..., dims)
+        # Flatten for score computation: (n_datasets * n_compositional, num_samples, ..., dims)
         xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims)
         log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1))
-        alpha_flat = ops.reshape(alpha_expanded, (n_datasets * n_compositional, num_samples, 1))
-        sigma_flat = ops.reshape(sigma_expanded, (n_datasets * n_compositional, num_samples, 1))
         conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims)

-        # Apply subnet
-        subnet_out = self._apply_subnet(xz_flat, log_snr_flat, conditions=conditions_flat, training=training)
-        pred = self.output_projector(subnet_out, training=training)
-
-        # Convert prediction to x
-        x_pred = self.convert_prediction_to_x(
-            pred=pred, z=xz_flat, alpha_t=alpha_flat, sigma_t=sigma_flat, log_snr_t=log_snr_flat
-        )
-
-        # Compute score: (α_t * x_pred - z) / σ_t²
-        score = (alpha_flat * x_pred - xz_flat) / ops.square(sigma_flat)
+        # Use standard score function
+        scores_flat = self.score(xz_flat, log_snr_t=log_snr_flat, conditions=conditions_flat, training=training)

         # Reshape back to compositional structure
-        score = ops.reshape(score, (n_datasets, n_compositional, num_samples) + dims)
-        return score
+        scores = ops.reshape(scores_flat, (n_datasets, n_compositional, num_samples) + dims)
+        return scores

     def _inverse_compositional(
         self,
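After this change, `_compute_individual_scores` follows an expand → broadcast → flatten → score → reshape pattern around the standard `self.score` call. A minimal NumPy sketch of that pattern (not part of the commit; `toy_score` and all shapes are illustrative stand-ins for the trained score network):

```python
# Sketch of the reshape pattern: expand over the compositional axis, fold it into the
# batch axis, evaluate the score once, then restore the compositional structure.
import numpy as np

def toy_score(xz_flat, log_snr_flat, conditions_flat):
    # Placeholder for self.score(...): returns one score per (dataset*condition, sample, dim)
    return -xz_flat * np.exp(log_snr_flat)

n_datasets, n_compositional, num_samples, dims, cond_dims = 4, 10, 32, (2,), (3,)

xz = np.random.randn(n_datasets, num_samples, *dims)
log_snr_t = np.random.randn(n_datasets, num_samples, 1)
conditions = np.random.randn(n_datasets, n_compositional, num_samples, *cond_dims)

# Expand xz and log_snr_t to (n_datasets, n_compositional, num_samples, ...)
xz_expanded = np.broadcast_to(xz[:, None], (n_datasets, n_compositional, num_samples) + dims)
log_snr_expanded = np.broadcast_to(log_snr_t[:, None], (n_datasets, n_compositional, num_samples, 1))

# Flatten the compositional axis into the batch axis
xz_flat = xz_expanded.reshape((n_datasets * n_compositional, num_samples) + dims)
log_snr_flat = log_snr_expanded.reshape((n_datasets * n_compositional, num_samples, 1))
conditions_flat = conditions.reshape((n_datasets * n_compositional, num_samples) + cond_dims)

scores_flat = toy_score(xz_flat, log_snr_flat, conditions_flat)

# Restore the compositional structure: (n_datasets, n_compositional, num_samples, ...)
scores = scores_flat.reshape((n_datasets, n_compositional, num_samples) + dims)
print(scores.shape)  # (4, 10, 32, 2)
```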
