@@ -721,7 +721,6 @@ def compositional_score(
721721 # Calculate standard noise schedule components
722722 log_snr_t = expand_right_as (self .noise_schedule .get_log_snr (t = time , training = training ), xz )
723723 log_snr_t = ops .broadcast_to (log_snr_t , ops .shape (xz )[:- 1 ] + (1 ,))
724- alpha_t , sigma_t = self .noise_schedule .get_alpha_sigma (log_snr_t = log_snr_t )
725724
726725 # Compute individual dataset scores
727726 if mini_batch_size is not None and mini_batch_size < n_compositional :
@@ -731,13 +730,13 @@ def compositional_score(
731730 conditions_batch = conditions [:, mini_batch_idx ]
732731 else :
733732 conditions_batch = conditions
734- individual_scores = self ._compute_individual_scores (xz , log_snr_t , alpha_t , sigma_t , conditions_batch , training )
733+ individual_scores = self ._compute_individual_scores (xz , log_snr_t , conditions_batch , training )
735734
736735 # Compute prior score component
737736 prior_score = compute_prior_score (xz )
738737 weighted_prior_score = (1.0 - n_compositional ) * (1.0 - time ) * prior_score
739738
740- # Sum individual scores across compositional dimensiont
739+ # Sum individual scores across compositional dimensions
741740 summed_individual_scores = n_compositional * ops .mean (individual_scores , axis = 1 )
742741
743742 # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ)
@@ -749,8 +748,6 @@ def _compute_individual_scores(
749748 self ,
750749 xz : Tensor ,
751750 log_snr_t : Tensor ,
752- alpha_t : Tensor ,
753- sigma_t : Tensor ,
754751 conditions : Tensor ,
755752 training : bool ,
756753 ) -> Tensor :
@@ -762,9 +759,6 @@ def _compute_individual_scores(
762759 Tensor
763760 Individual scores with shape (n_datasets, n_compositional, ...)
764761 """
765- # Apply subnet to each compositional condition separately
766- transformed_log_snr = self ._transform_log_snr (log_snr_t )
767-
768762 # Get shapes
769763 xz_shape = ops .shape (xz ) # (n_datasets, num_samples, ..., dims)
770764 conditions_shape = ops .shape (conditions ) # (n_datasets, n_compositional, num_samples, ..., dims)
@@ -777,38 +771,21 @@ def _compute_individual_scores(
777771 xz_expanded = ops .expand_dims (xz , axis = 1 ) # (n_datasets, 1, num_samples, ..., dims)
778772 xz_expanded = ops .broadcast_to (xz_expanded , (n_datasets , n_compositional , num_samples ) + dims )
779773
780- # Expand noise schedule components to match compositional structure
781- log_snr_expanded = ops .expand_dims (transformed_log_snr , axis = 1 )
774+ # Expand log_snr_t to match compositional structure
775+ log_snr_expanded = ops .expand_dims (log_snr_t , axis = 1 )
782776 log_snr_expanded = ops .broadcast_to (log_snr_expanded , (n_datasets , n_compositional , num_samples , 1 ))
783777
784- alpha_expanded = ops .expand_dims (alpha_t , axis = 1 )
785- alpha_expanded = ops .broadcast_to (alpha_expanded , (n_datasets , n_compositional , num_samples , 1 ))
786-
787- sigma_expanded = ops .expand_dims (sigma_t , axis = 1 )
788- sigma_expanded = ops .broadcast_to (sigma_expanded , (n_datasets , n_compositional , num_samples , 1 ))
789-
790- # Flatten for subnet application: (n_datasets * n_compositional, num_samples, ..., dims)
778+ # Flatten for score computation: (n_datasets * n_compositional, num_samples, ..., dims)
791779 xz_flat = ops .reshape (xz_expanded , (n_datasets * n_compositional , num_samples ) + dims )
792780 log_snr_flat = ops .reshape (log_snr_expanded , (n_datasets * n_compositional , num_samples , 1 ))
793- alpha_flat = ops .reshape (alpha_expanded , (n_datasets * n_compositional , num_samples , 1 ))
794- sigma_flat = ops .reshape (sigma_expanded , (n_datasets * n_compositional , num_samples , 1 ))
795781 conditions_flat = ops .reshape (conditions , (n_datasets * n_compositional , num_samples ) + conditions_dims )
796782
797- # Apply subnet
798- subnet_out = self ._apply_subnet (xz_flat , log_snr_flat , conditions = conditions_flat , training = training )
799- pred = self .output_projector (subnet_out , training = training )
800-
801- # Convert prediction to x
802- x_pred = self .convert_prediction_to_x (
803- pred = pred , z = xz_flat , alpha_t = alpha_flat , sigma_t = sigma_flat , log_snr_t = log_snr_flat
804- )
805-
806- # Compute score: (α_t * x_pred - z) / σ_t²
807- score = (alpha_flat * x_pred - xz_flat ) / ops .square (sigma_flat )
783+ # Use standard score function
784+ scores_flat = self .score (xz_flat , log_snr_t = log_snr_flat , conditions = conditions_flat , training = training )
808785
809786 # Reshape back to compositional structure
810- score = ops .reshape (score , (n_datasets , n_compositional , num_samples ) + dims )
811- return score
787+ scores = ops .reshape (scores_flat , (n_datasets , n_compositional , num_samples ) + dims )
788+ return scores
812789
813790 def _inverse_compositional (
814791 self ,
0 commit comments