@@ -57,7 +57,8 @@ def from_config(cls, config):
5757
5858 return cls (** config )
5959
60- def get_head_shapes_from_target_shape (self , target_shape ):
60+ def get_head_shapes_from_target_shape (self , target_shape : Shape ) -> dict [str , Shape ]:
61+ """Request a dictionary of names and output shapes of required heads from the score."""
6162 raise NotImplementedError
6263
6364 def get_subnet (self , key : str ) -> keras .Layer :
@@ -79,23 +80,86 @@ def get_subnet(self, key: str) -> keras.Layer:
7980 else :
8081 return find_network (self .subnets [key ], ** self .subnets_kwargs .get (key , {}))
8182
82- def get_link (self , key : str ):
83+ def get_link (self , key : str ) -> keras .Layer :
84+ """For a specified key, request a link from network output to estimation target.
85+
86+ Parameters
87+ ----------
88+ key : str
89+ Name of head for which to request a link.
90+
91+ Returns
92+ -------
93+ link : keras.Layer
94+ Activation function linking network output to estimation target.
95+ """
8396 if key not in self .links .keys ():
8497 return keras .layers .Activation ("linear" )
8598 elif isinstance (self .links [key ], str ):
8699 return keras .layers .Activation (self .links [key ])
87100 else :
88101 return self .links [key ]
89102
90- def get_head (self , key : str , shape : Shape ):
103+ def get_head (self , key : str , shape : Shape ) -> keras .Sequential :
104+ """For a specified head key and shape, request corresponding head network.
105+
106+ Parameters
107+ ----------
108+ key : str
109+ Name of head for which to request a link.
110+
111+ Returns
112+ -------
113+ head : keras.Sequential
114+ Head network consisting of a learnable projection, a reshape and a link operation
115+ to parameterize estimates.
116+ """
91117 subnet = self .get_subnet (key )
92118 dense = keras .layers .Dense (units = math .prod (shape ))
93119 reshape = keras .layers .Reshape (target_shape = shape )
94120 link = self .get_link (key )
95121 return keras .Sequential ([subnet , dense , reshape , link ])
96122
97123 def score (self , estimates : dict [str , Tensor ], targets : Tensor , weights : Tensor ) -> Tensor :
98- """Scores a probabilistic estimate based of a distribution based on samples of that distribution."""
124+ """Scores a batch of probabilistic estimates of distributions based on samples
125+ of the corresponding distributions.
126+
127+ Parameters
128+ ----------
129+ estimates : dict[str, Tensor]
130+ Dictionary of estimates.
131+ targets : Tensor
132+ Tensor of samples fromt the true distribution to evaluate the estimates.
133+
134+ Returns
135+ -------
136+ numeric_score : Tensor
137+ Negatively oriented score evaluating the estimates, aggregated for the whole batch.
138+
139+ Examples
140+ --------
141+ The following shows how to score estimates with a ``MeanScore``. All ``ScoringRule`` s follow this pattern,
142+ only differing in the structure of the estimates dictionary.
143+
144+ >>> import keras
145+ ... from bayesflow.scores import MeanScore
146+ >>>
147+ >>> # batch of samples from a normal distribution
148+ >>> samples = keras.random.normal(shape=(100,))
149+ >>>
150+ >>> # batch of uninformed (random) estimates
151+ >>> bad_estimates = {"value": keras.random.uniform((100,))}
152+ >>>
153+ >>> # batch of estimates that are closer to the true mean
154+ >>> better_estimates = {"value": keras.random.normal(stddev=0.1, shape=(100,))}
155+ >>>
156+ >>> # calculate the score
157+ >>> scoring_rule = MeanScore()
158+ >>> scoring_rule.score(bad_estimates, samples)
159+ <tf.Tensor: shape=(), dtype=float32, numpy=1.2243813276290894>
160+ >>> scoring_rule.score(better_estimates, samples)
161+ <tf.Tensor: shape=(), dtype=float32, numpy=1.013983130455017>
162+ """
99163 raise NotImplementedError
100164
101165 def aggregate (self , scores : Tensor , weights : Tensor = None ):
0 commit comments