@@ -27,51 +27,77 @@ def estimate(
2727 if not self .built :
2828 raise AssertionError ("PointApproximator needs to be built before predicting with it." )
2929
30+ # Prepare the input conditions.
31+ conditions = self ._prepare_conditions (conditions , ** kwargs )
32+ # Run the internal estimation and convert the output to numpy.
33+ estimates = self ._run_inference (conditions , ** kwargs )
34+ # Postprocess the inference output with the inverse adapter.
35+ estimates = self ._apply_inverse_adapter (estimates , ** kwargs )
36+ # Optionally split the arrays along the last axis.
37+ if split :
38+ estimates = split_arrays (estimates , axis = - 1 )
39+ # Reorder the nested dictionary so that original variable names are at the top.
40+ estimates = self ._reorder_estimates (estimates )
41+ # Remove unnecessary nesting.
42+ estimates = self ._squeeze_estimates (estimates )
43+
44+ return estimates
45+
46+ def _prepare_conditions (self , conditions : dict [str , np .ndarray ], ** kwargs ) -> dict [str , Tensor ]:
47+ """Adapts and converts the conditions to tensors."""
3048 conditions = self .adapter (conditions , strict = False , stage = "inference" , ** kwargs )
31- conditions = keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
32- conditions = {"inference_variables" : self ._estimate (** conditions , ** kwargs )}
33- conditions = keras .tree .map_structure (keras .ops .convert_to_numpy , conditions )
34- conditions = {
35- outer_key : {
36- inner_key : self .adapter (
37- dict (inference_variables = conditions ["inference_variables" ][outer_key ][inner_key ]),
49+ return keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
50+
51+ def _run_inference (self , conditions : dict [str , Tensor ], ** kwargs ) -> dict [str , dict [str , np .ndarray ]]:
52+ """Runs the internal _estimate function and converts the result to numpy arrays."""
53+ # Run the estimation.
54+ inference_output = self ._estimate (** conditions , ** kwargs )
55+ # Wrap the result in a dict and convert to numpy.
56+ wrapped_output = {"inference_variables" : inference_output }
57+ return keras .tree .map_structure (keras .ops .convert_to_numpy , wrapped_output )
58+
59+ def _apply_inverse_adapter (
60+ self , estimates : dict [str , dict [str , np .ndarray ]], ** kwargs
61+ ) -> dict [str , dict [str , dict [str , np .ndarray ]]]:
62+ """Applies the inverse adapter on each inner element of the inference outputs."""
63+ processed = {}
64+ for score_key , score_val in estimates ["inference_variables" ].items ():
65+ processed [score_key ] = {}
66+ for head_key , estimate in score_val .items ():
67+ adapted = self .adapter (
68+ {"inference_variables" : estimate },
3869 inverse = True ,
3970 strict = False ,
4071 ** kwargs ,
4172 )
42- for inner_key in conditions ["inference_variables" ][outer_key ].keys ()
43- }
44- for outer_key in conditions ["inference_variables" ].keys ()
45- }
73+ processed [score_key ][head_key ] = adapted
74+ return processed
4675
47- if split :
48- conditions = split_arrays (conditions , axis = - 1 )
49-
50- # get original variable names to reorder them to highest level
51- inference_variable_names = next (iter (next (iter (conditions .values ())).values ())).keys ()
52-
53- # change ordering of nested dictionary
54- conditions = {
55- variable_name : {
56- outer_key : {
57- inner_key : conditions [outer_key ][inner_key ][variable_name ]
58- for inner_key in conditions [outer_key ].keys ()
59- }
60- for outer_key in conditions .keys ()
61- }
62- for variable_name in inference_variable_names
63- }
64-
65- # remove unnecessary nesting
66- conditions = {
67- variable_name : {
68- outer_key : squeeze_inner_estimates_dict (conditions [variable_name ][outer_key ])
69- for outer_key in conditions [variable_name ].keys ()
70- }
71- for variable_name in conditions .keys ()
72- }
76+ def _reorder_estimates (
77+ self , estimates : dict [str , dict [str , dict [str , np .ndarray ]]]
78+ ) -> dict [str , dict [str , dict [str , np .ndarray ]]]:
79+ """Reorders the nested dictionary so that the inference variable names become the top-level keys."""
80+ # Grab the variable names from one sample inner dictionary.
81+ sample_inner = next (iter (next (iter (estimates .values ())).values ()))
82+ variable_names = sample_inner .keys ()
83+ reordered = {}
84+ for variable in variable_names :
85+ reordered [variable ] = {}
86+ for score_key , inner_dict in estimates .items ():
87+ reordered [variable ][score_key ] = {inner_key : value [variable ] for inner_key , value in inner_dict .items ()}
88+ return reordered
7389
74- return conditions
90+ def _squeeze_estimates (
91+ self , estimates : dict [str , dict [str , dict [str , np .ndarray ]]]
92+ ) -> dict [str , dict [str , np .ndarray ]]:
93+ """Squeezes each inner estimate dictionary to remove unnecessary nesting."""
94+ squeezed = {}
95+ for variable , variable_estimates in estimates .items ():
96+ squeezed [variable ] = {
97+ score_key : squeeze_inner_estimates_dict (inner_estimate )
98+ for score_key , inner_estimate in variable_estimates .items ()
99+ }
100+ return squeezed
75101
76102 def _estimate (
77103 self ,
0 commit comments