@@ -409,18 +409,10 @@ def sample(
409409 dict[str, np.ndarray]
410410 Dictionary containing generated samples with the same keys as `conditions`.
411411 """
412-
413- # Apply adapter transforms to raw simulated / real quantities
414- conditions = self .adapter (conditions , strict = False , stage = "inference" , ** kwargs )
415-
416- # Ensure only keys relevant for sampling are present in the conditions dictionary
412+ # Adapt, optionally standardize and convert conditions to tensor.
413+ conditions = self ._prepare_data (conditions , ** kwargs )
414+ # Remove any superfluous keys, just retain actual conditions. # TODO: is this necessary?
417415 conditions = {k : v for k , v in conditions .items () if k in ContinuousApproximator .CONDITION_KEYS }
418- conditions = keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
419-
420- # Optionally standardize conditions
421- for key in ContinuousApproximator .CONDITION_KEYS :
422- if key in conditions and key in self .standardize :
423- conditions [key ] = self .standardize_layers [key ](conditions [key ])
424416
425417 # Sample and undo optional standardization
426418 samples = self ._sample (num_samples = num_samples , ** conditions , ** kwargs )
@@ -438,6 +430,51 @@ def sample(
438430 samples = split_arrays (samples , axis = - 1 )
439431 return samples
440432
def _prepare_data(
    self, data: Mapping[str, np.ndarray], log_det_jac: bool = False, **kwargs
) -> dict[str, Tensor] | tuple[dict[str, Tensor], np.ndarray | float]:
    """
    Adapt, optionally standardize, and convert `data` to tensors for the networks.

    `data` may hold only conditions, only `inference_variables`, or both; every
    step below is applied only to the keys that are actually present.

    Parameters
    ----------
    data : Mapping[str, np.ndarray]
        Raw quantities that are passed through the adapter.
    log_det_jac : bool, optional
        If True, additionally accumulate the log-determinant of the Jacobian of
        all transformations applied to `inference_variables` (adapter plus
        optional standardization) and return it alongside the data.
        Default is False.
    **kwargs
        Forwarded to the adapter call.

    Returns
    -------
    dict[str, Tensor]
        The prepared data, if `log_det_jac` is False.
    tuple[dict[str, Tensor], np.ndarray | float]
        The prepared data and the accumulated log-det-Jacobian for
        `inference_variables`, if `log_det_jac` is True. The second element is
        `0.0` when neither the adapter nor a standardization layer contributes
        a term.
    """
    # With `log_det_jac=True` the adapter result is unpacked as a pair
    # `(adapted_data, log_det_jac_per_key)`; otherwise it is the adapted data.
    adapted = self.adapter(data, strict=False, stage="inference", log_det_jac=log_det_jac, **kwargs)
    if log_det_jac:
        data, log_det_jac_adapter = adapted
        log_det_jac_inference_variables = log_det_jac_adapter.get("inference_variables", 0.0)
    else:
        data = adapted

    # Optionally standardize conditions, if they are part of data.
    for key in ContinuousApproximator.CONDITION_KEYS:
        if key in data and key in self.standardize:
            data[key] = self.standardize_layers[key](data[key])

    # Optionally standardize inference variables, if they are part of data.
    if "inference_variables" in data and "inference_variables" in self.standardize:
        standardized = self.standardize_layers["inference_variables"](
            data["inference_variables"], log_det_jac=log_det_jac
        )

        if log_det_jac:
            data["inference_variables"], log_det_std = standardized
            # Out-of-place addition: avoids mutating the adapter's array and
            # lets NumPy broadcast when the two terms have different shapes.
            log_det_jac_inference_variables = log_det_jac_inference_variables + keras.ops.convert_to_numpy(
                log_det_std
            )
        else:
            data["inference_variables"] = standardized

    # Convert to tensor and return.
    data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
    if log_det_jac:
        # Return the accumulated log-det-Jacobian, not the boolean flag.
        return data, log_det_jac_inference_variables
    return data
477+
441478 def _sample (
442479 self ,
443480 num_samples : int ,
@@ -517,24 +554,14 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
517554 np.ndarray
518555 Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
519556 """
520- data , log_det_jac = self .adapter (data , strict = False , stage = "inference" , log_det_jac = True , ** kwargs )
521- log_det_jac = log_det_jac .get ("inference_variables" , 0.0 )
522-
523- # Optionally standardize conditions
524- for key in ContinuousApproximator .CONDITION_KEYS :
525- if key in data and key in self .standardize :
526- data [key ] = self .standardize_layers [key ](data [key ])
557+ # Adapt, optionally standardize and convert to tensor. Keep track of log_det_jac.
558+ data , log_det_jac = self ._prepare_data (data , log_det_jac = True , ** kwargs )
527559
528- # Optionally standardize inference variables
529- if "inference_variables" in self .standardize :
530- data ["inference_variables" ], log_det_std = self .standardize_layers ["inference_variables" ](
531- data ["inference_variables" ], log_det_jac = True
532- )
533- log_det_jac += keras .ops .convert_to_numpy (log_det_std )
534-
535- data = keras .tree .map_structure (keras .ops .convert_to_tensor , data )
560+ # Pass data to networks and convert back to numpy array.
536561 log_prob = self ._log_prob (** data , ** kwargs )
537562 log_prob = keras .ops .convert_to_numpy (log_prob )
563+
564+ # Change of variables formula.
538565 log_prob = log_prob + log_det_jac
539566
540567 return log_prob
0 commit comments