@@ -18,10 +18,15 @@ def __init__(self, approximators: dict[str, Approximator], **kwargs):
1818
1919 self .num_approximators = len (self .approximators )
2020
21+ def build_from_data (self , adapted_data : dict [str , any ]):
22+ data_shapes = keras .tree .map_structure (keras .ops .shape , adapted_data )
23+ if len (data_shapes ["inference_variables" ]) > 2 :
24+ # Remove the ensemble dimension from data_shapes. This expects data_shapes are the shapes of a
25+ # batch of training data, where the second axis corresponds to different approximators.
26+ data_shapes = {k : v [:1 ] + v [2 :] for k , v in data_shapes .items ()}
27+ self .build (data_shapes )
28+
2129 def build (self , data_shapes : dict [str , tuple [int ] | dict [str , dict ]]) -> None :
22- # Remove the ensemble dimension from data_shapes. This expects data_shapes are the shapes of a
23- # batch of training data, where the second axis corresponds to different approximators.
24- data_shapes = {k : v [:1 ] + v [2 :] for k , v in data_shapes .items ()}
2530 for approximator in self .approximators .values ():
2631 approximator .build (data_shapes )
2732
@@ -82,7 +87,7 @@ def sample(
8287 conditions : Mapping [str , np .ndarray ],
8388 split : bool = False ,
8489 ** kwargs ,
85- ) -> dict [str , np .ndarray ]:
90+ ) -> dict [str , dict [ str , np .ndarray ] ]:
8691 samples = {}
8792 for approx_name , approximator in self .approximators .items ():
8893 if self ._has_obj_method (approximator , "sample" ):
@@ -91,6 +96,25 @@ def sample(
9196 )
9297 return samples
9398
99+ def log_prob (self , data : Mapping [str , np .ndarray ], ** kwargs ) -> dict [str , np .ndarray ]:
100+ log_prob = {}
101+ for approx_name , approximator in self .approximators .items ():
102+ if self ._has_obj_method (approximator , "log_prob" ):
103+ log_prob [approx_name ] = approximator .log_prob (data = data , ** kwargs )
104+ return log_prob
105+
106+ def estimate (
107+ self ,
108+ conditions : Mapping [str , np .ndarray ],
109+ split : bool = False ,
110+ ** kwargs ,
111+ ) -> dict [str , dict [str , dict [str , np .ndarray ]]]:
112+ estimates = {}
113+ for approx_name , approximator in self .approximators .items ():
114+ if self ._has_obj_method (approximator , "estimate" ):
115+ estimates [approx_name ] = approximator .estimate (conditions = conditions , split = split , ** kwargs )
116+ return estimates
117+
94118 def _has_obj_method (self , obj , name ):
95119 method = getattr (obj , name , None )
96120 return callable (method )
0 commit comments