@@ -27,12 +27,9 @@ def estimate(
2727 if not self .built :
2828 raise AssertionError ("PointApproximator needs to be built before predicting with it." )
2929
30- # Prepare the input conditions.
3130 conditions = self ._prepare_conditions (conditions , ** kwargs )
32- # Run the internal estimation and convert the output to numpy.
33- estimates = self ._run_inference (conditions , ** kwargs )
34- # Postprocess the inference output with the inverse adapter.
35- estimates = self ._apply_inverse_adapter (estimates , ** kwargs )
31+ estimates = self ._estimate (** conditions , ** kwargs )
32+ estimates = self ._apply_inverse_adapter_to_estimates (estimates , ** kwargs )
3633 # Optionally split the arrays along the last axis.
3734 if split :
3835 estimates = split_arrays (estimates , axis = - 1 )
@@ -43,25 +40,40 @@ def estimate(
4340
4441 return estimates
4542
43+ def sample (
44+ self ,
45+ * ,
46+ num_samples : int ,
47+ conditions : dict [str , np .ndarray ],
48+ split : bool = False ,
49+ ** kwargs ,
50+ ) -> dict [str , np .ndarray ]:
51+ if not self .built :
52+ raise AssertionError ("This model needs to be built before using it for sampling." )
53+
54+ conditions = self ._prepare_conditions (conditions , ** kwargs )
55+ samples = self ._sample (num_samples , ** conditions , ** kwargs )
56+ samples = self ._apply_inverse_adapter_to_samples (samples , ** kwargs )
57+ # Optionally split the arrays along the last axis.
58+ if split :
59+ samples = split_arrays (samples , axis = - 1 )
60+ # Squeeze samples if there's only one key-value pair.
61+ samples = self ._squeeze_samples (samples )
62+
63+ return samples
64+
4665 def _prepare_conditions (self , conditions : dict [str , np .ndarray ], ** kwargs ) -> dict [str , Tensor ]:
4766 """Adapts and converts the conditions to tensors."""
4867 conditions = self .adapter (conditions , strict = False , stage = "inference" , ** kwargs )
4968 return keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
5069
51- def _run_inference (self , conditions : dict [str , Tensor ], ** kwargs ) -> dict [str , dict [str , np .ndarray ]]:
52- """Runs the internal _estimate function and converts the result to numpy arrays."""
53- # Run the estimation.
54- inference_output = self ._estimate (** conditions , ** kwargs )
55- # Wrap the result in a dict and convert to numpy.
56- wrapped_output = {"inference_variables" : inference_output }
57- return keras .tree .map_structure (keras .ops .convert_to_numpy , wrapped_output )
58-
59- def _apply_inverse_adapter (
60- self , estimates : dict [str , dict [str , np .ndarray ]], ** kwargs
70+ def _apply_inverse_adapter_to_estimates (
71+ self , estimates : dict [str , dict [str , Tensor ]], ** kwargs
6172 ) -> dict [str , dict [str , dict [str , np .ndarray ]]]:
62- """Applies the inverse adapter on each inner element of the inference outputs."""
73+ """Applies the inverse adapter on each inner element of the _estimate output dictionary."""
74+ estimates = keras .tree .map_structure (keras .ops .convert_to_numpy , estimates )
6375 processed = {}
64- for score_key , score_val in estimates [ "inference_variables" ] .items ():
76+ for score_key , score_val in estimates .items ():
6577 processed [score_key ] = {}
6678 for head_key , estimate in score_val .items ():
6779 adapted = self .adapter (
@@ -73,6 +85,21 @@ def _apply_inverse_adapter(
7385 processed [score_key ][head_key ] = adapted
7486 return processed
7587
88+ def _apply_inverse_adapter_to_samples (
89+ self , samples : dict [str , Tensor ], ** kwargs
90+ ) -> dict [str , dict [str , np .ndarray ]]:
91+ """Applies the inverse adapter to a dictionary of samples."""
92+ samples = keras .tree .map_structure (keras .ops .convert_to_numpy , samples )
93+ processed = {}
94+ for score_key , samples in samples .items ():
95+ processed [score_key ] = self .adapter (
96+ {"inference_variables" : samples },
97+ inverse = True ,
98+ strict = False ,
99+ ** kwargs ,
100+ )
101+ return processed
102+
76103 def _reorder_estimates (
77104 self , estimates : dict [str , dict [str , dict [str , np .ndarray ]]]
78105 ) -> dict [str , dict [str , dict [str , np .ndarray ]]]:
@@ -99,6 +126,12 @@ def _squeeze_estimates(
99126 }
100127 return squeezed
101128
129+ def _squeeze_samples (self , samples : dict [str , np .ndarray ]) -> np .ndarray or dict [str , np .ndarray ]:
130+ """Squeezes the samples dictionary to just the value if there is only one key-value pair."""
131+ if len (samples ) == 1 :
132+ return next (iter (samples .values ())) # Extract and return the only item's value
133+ return samples
134+
102135 def _estimate (
103136 self ,
104137 inference_conditions : Tensor = None ,
0 commit comments