 import keras
 
+from math import prod
+
+from collections.abc import Callable
+
+from bayesflow.utils import keras_kwargs, find_network
 from bayesflow.types import Shape, Tensor
+from bayesflow.scoring_rules import ScoringRule
+
+# TODO:
+# * [ ] weight initialization
+# * [ ] serializable?
+# * [ ] testing
+# * [ ] docstrings
 
 
 class PointInferenceNetwork(keras.Layer):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        scoring_rules: dict[str, ScoringRule],
+        body_subnet: str | type = "mlp",  # naming: shared_subnet / body / subnet?
+        heads_subnet: dict[str, str | keras.Layer] = None,  # TODO: `type` instead of `keras.Layer`? Too specific?
+        activations: dict[str, keras.layers.Activation | Callable | str] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            **keras_kwargs(kwargs)
+        )  # TODO: need for bf.utils.keras_kwargs in regular InferenceNetwork class? seems to be a bug
+
+        self.scoring_rules = scoring_rules
+        # For now PointInferenceNetwork uses the same scoring rules for all parameters
+        # To support using different sets of scoring rules for different parameter (blocks),
+        # we can look into renaming this class to sth like `HeadCollection` and
+        # handle the split in a higher-level object. (PointApproximator?)
+
+        self.body_subnet = find_network(body_subnet, **kwargs.get("body_subnet_kwargs", {}))
+
+        if heads_subnet:
+            self.heads = {
+                key: [find_network(value, **kwargs.get("heads_subnet_kwargs", {}).get(key, {}))]
+                for key, value in heads_subnet.items()
+            }
+        else:
+            self.heads = {key: [] for key in self.scoring_rules.keys()}
+
+        if activations:
+            self.activations = {
+                key: (value if isinstance(value, keras.layers.Activation) else keras.layers.Activation(value))
+                for key, value in activations.items()
+            }  # make sure that each value is an Activation object
+        else:
+            self.activations = {key: keras.layers.Activation("linear") for key in self.scoring_rules.keys()}
+            # TODO: Stefan suggested to call these link functions, decide on this
+
+        for key in self.heads.keys():
+            self.heads[key] += [
+                keras.layers.Dense(units=None),
+                keras.layers.Reshape(target_shape=(None,)),
+                self.activations[key],
+            ]
+
+        # TODO: allow key-wise overriding of the default, instead of just complete default or totally custom choices
+
+        assert set(self.scoring_rules.keys()) == set(self.heads.keys()) == set(self.activations.keys())
 
     def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
-        pass
+        # build the shared body network
+        input_shape = conditions_shape
+        self.body_subnet.build(input_shape)
+        body_output_shape = self.body_subnet.compute_output_shape(input_shape)
+
+        for key in self.heads.keys():
+            # head_output_shape (excluding batch_size) convention is (*prediction_shape, *parameter_block_shape)
+            prediction_shape = self.scoring_rules[key].prediction_shape
+            head_output_shape = prediction_shape + xz_shape[1:]
+
+            # set correct head shape
+            self.heads[key][-3].units = prod(head_output_shape)
+            self.heads[key][-2].target_shape = head_output_shape
+
+            # build head block by block
+            input_shape = body_output_shape
+            for head_block in self.heads[key]:
+                head_block.build(input_shape)
+                input_shape = head_block.compute_output_shape(input_shape)
 
     def call(
         self,
@@ -17,19 +92,37 @@ def call(
         training: bool = False,
         **kwargs,
     ) -> Tensor | tuple[Tensor, Tensor]:
+        # TODO: remove unnecessary similarity with InferenceNetwork
         return self._forward(xz, conditions=conditions, training=training, **kwargs)
 
     def _forward(
         self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
-        raise NotImplementedError
+        body_output = self.body_subnet(conditions)
+
+        output = dict()
+        for key, head in self.heads.items():
+            y = body_output
+            for head_block in head:
+                y = head_block(y)
+
+            output |= {key: y}
+        return output
 
     def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
         if not self.built:
             xz_shape = keras.ops.shape(x)
             conditions_shape = None if conditions is None else keras.ops.shape(conditions)
             self.build(xz_shape, conditions_shape=conditions_shape)
 
+        output = self(x, conditions)
+
+        # calculate negative score as mean over all heads
+        neg_score = 0
+        for key, rule in self.scoring_rules.items():
+            neg_score += rule.score(output[key], x)
+        neg_score /= len(self.scoring_rules)
+
         metrics = {}
 
         if stage != "training" and any(self.metrics):
@@ -41,7 +134,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
             pass
             # TODO: instead compute estimate based metrics
 
-        return metrics
+        return metrics | {"loss": neg_score}
 
     def estimate(self, conditions: Tensor = None) -> Tensor:
         return self._forward(None, conditions)
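
For reviewers, a minimal usage sketch of the class added above. The `MeanAbsoluteScore` rule, the tensor shapes, and the variable names are illustrative assumptions, not part of bayesflow: the toy rule only mirrors the interface this diff relies on (a `prediction_shape` attribute and a `score(prediction, target)` method), since the actual `ScoringRule` base class is not shown in this diff.

```python
import keras
import numpy as np


class MeanAbsoluteScore:
    """Toy scoring rule (mean absolute error, i.e. a median-targeting score).

    Hypothetical stand-in that only provides the attributes PointInferenceNetwork
    reads: `prediction_shape` and `score(prediction, target)`.
    """

    prediction_shape = ()  # one point prediction per parameter block

    def score(self, prediction, target):
        return keras.ops.mean(keras.ops.abs(prediction - target))


# one head per scoring rule key; each head is body -> Dense -> Reshape -> Activation
network = PointInferenceNetwork(scoring_rules={"median": MeanAbsoluteScore()})

# made-up shapes: 16 data sets, 4 parameters, 10 summary/condition features
x = keras.ops.convert_to_tensor(np.random.normal(size=(16, 4)).astype("float32"))
conditions = keras.ops.convert_to_tensor(np.random.normal(size=(16, 10)).astype("float32"))

# builds lazily and returns the scores averaged over heads as {"loss": ...}
metrics = network.compute_metrics(x, conditions=conditions)

# returns a dict with one tensor per head, here {"median": (16, 4)}
estimates = network.estimate(conditions=conditions)
```

With the default `"linear"` activation each head returns unconstrained predictions of shape `(batch_size, *prediction_shape, *parameter_block_shape)`, and `compute_metrics` folds all heads into the single `"loss"` entry shown above.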