File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -691,7 +691,6 @@ def compositional_sample(
691691
692692 # Prepare prior scores to handle adapter
693693 def compute_prior_score_pre (_samples : Tensor ) -> Tensor :
694- return keras .ops .zeros_like (_samples )
695694 if "inference_variables" in self .standardize :
696695 _samples , log_det_jac_standardize = self .standardize_layers ["inference_variables" ](
697696 _samples , forward = False , log_det_jac = True
@@ -707,7 +706,8 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor:
707706 prior_score [key ] = prior_score [key ]
708707 if len (log_det_jac ) > 0 :
709708 prior_score [key ] += log_det_jac [key ]
710- prior_score [key ] = keras .ops .convert_to_tensor (prior_score [key ])
709+
710+ prior_score = keras .tree .map_structure (keras .ops .convert_to_tensor , prior_score )
711711 # make a tensor
712712 out = keras .ops .concatenate (
713713 list (prior_score .values ()), axis = - 1
You can’t perform that action at this time.
0 commit comments