Commit 093785d

Draft implementation of quantile estimation

1 parent fc86d4d commit 093785d

7 files changed: +294 -1 lines changed
bayesflow/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 )

 from .workflows import BasicWorkflow
-from .approximators import ContinuousApproximator
+from .approximators import ContinuousApproximator, ContinuousPointApproximator
 from .adapters import Adapter
 from .datasets import OfflineDataset, OnlineDataset, DiskDataset
 from .simulators import make_simulator
bayesflow/approximators/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 from .approximator import Approximator
 from .continuous_approximator import ContinuousApproximator
+from .continuous_point_approximator import ContinuousPointApproximator
 from .model_comparison_approximator import ModelComparisonApproximator
bayesflow/approximators/continuous_point_approximator.py

Lines changed: 170 additions & 0 deletions

@@ -0,0 +1,170 @@
+from collections.abc import Sequence
+
+import keras
+import numpy as np
+from keras.saving import (
+    deserialize_keras_object as deserialize,
+    register_keras_serializable as serializable,
+    serialize_keras_object as serialize,
+)
+
+from bayesflow.adapters import Adapter
+from bayesflow.networks import PointInferenceNetwork, SummaryNetwork
+from bayesflow.types import Tensor
+from bayesflow.utils import logging, split_arrays
+from .approximator import Approximator
+
+
+@serializable(package="bayesflow.approximators")
+class ContinuousPointApproximator(Approximator):
+    """
+    Defines a workflow for performing fast posterior or likelihood inference.
+    The distribution is approximated by point estimates produced by a feed-forward network, with an optional summary network.
+    """
+
+    def __init__(
+        self,
+        *,
+        adapter: Adapter,
+        inference_network: PointInferenceNetwork,
+        summary_network: SummaryNetwork = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.adapter = adapter
+        self.inference_network = inference_network
+        self.summary_network = summary_network
+
+    @classmethod
+    def build_adapter(
+        cls,
+        inference_variables: Sequence[str],
+        inference_conditions: Sequence[str] = None,
+        summary_variables: Sequence[str] = None,
+    ) -> Adapter:
+        adapter = Adapter.create_default(inference_variables)
+
+        if inference_conditions is not None:
+            adapter = adapter.concatenate(inference_conditions, into="inference_conditions")
+
+        if summary_variables is not None:
+            adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables")
+
+        adapter = adapter.keep(["inference_variables", "inference_conditions", "summary_variables"]).standardize()
+
+        return adapter
+
+    def compile(
+        self,
+        *args,
+        inference_metrics: Sequence[keras.Metric] = None,
+        summary_metrics: Sequence[keras.Metric] = None,
+        **kwargs,
+    ):
+        if inference_metrics:
+            self.inference_network._metrics = inference_metrics
+
+        if summary_metrics:
+            if self.summary_network is None:
+                logging.warning("Ignoring summary metrics because there is no summary network.")
+            else:
+                self.summary_network._metrics = summary_metrics
+
+        return super().compile(*args, **kwargs)
+
+    def compute_metrics(
+        self,
+        inference_variables: Tensor,
+        inference_conditions: Tensor = None,
+        summary_variables: Tensor = None,
+        stage: str = "training",
+    ) -> dict[str, Tensor]:
+        if self.summary_network is None:
+            if summary_variables is not None:
+                raise ValueError("Cannot compute summary metrics without a summary network.")
+
+            summary_metrics = {}
+        else:
+            if summary_variables is None:
+                raise ValueError("Summary variables are required when a summary network is present.")
+
+            summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
+            summary_outputs = summary_metrics.pop("outputs")
+
+            # append summary outputs to inference conditions
+            if inference_conditions is None:
+                inference_conditions = summary_outputs
+            else:
+                inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
+
+        inference_metrics = self.inference_network.compute_metrics(
+            inference_variables, conditions=inference_conditions, stage=stage
+        )
+
+        loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
+
+        inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()}
+        summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}
+
+        metrics = {"loss": loss} | inference_metrics | summary_metrics
+
+        return metrics
+
+    def fit(self, *args, **kwargs):
+        return super().fit(*args, **kwargs, adapter=self.adapter)
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        config["adapter"] = deserialize(config["adapter"], custom_objects=custom_objects)
+        config["inference_network"] = deserialize(config["inference_network"], custom_objects=custom_objects)
+        config["summary_network"] = deserialize(config["summary_network"], custom_objects=custom_objects)
+
+        return super().from_config(config, custom_objects=custom_objects)
+
+    def get_config(self):
+        base_config = super().get_config()
+        config = {
+            "adapter": serialize(self.adapter),
+            "inference_network": serialize(self.inference_network),
+            "summary_network": serialize(self.summary_network),
+        }
+
+        return base_config | config
+
+    def estimate(
+        self,
+        *,
+        conditions: dict[str, np.ndarray],
+        split: bool = False,
+        **kwargs,
+    ) -> dict[str, np.ndarray]:
+        conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
+        conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
+        conditions = {"inference_variables": self._estimate(**conditions)}
+        conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
+        conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
+
+        if split:
+            conditions = split_arrays(conditions, axis=-1)
+        return conditions
+
+    def _estimate(
+        self,
+        inference_conditions: Tensor = None,
+        summary_variables: Tensor = None,
+    ) -> Tensor:
+        if self.summary_network is None:
+            if summary_variables is not None:
+                raise ValueError("Cannot use summary variables without a summary network.")
+        else:
+            if summary_variables is None:
+                raise ValueError("Summary variables are required when a summary network is present.")
+
+            summary_outputs = self.summary_network(summary_variables)
+
+            if inference_conditions is None:
+                inference_conditions = summary_outputs
+            else:
+                inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=1)
+
+        return self.inference_network.estimate(conditions=inference_conditions)
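
For orientation, a minimal usage sketch of the new approximator (an illustration, not part of this commit; the variable names, quantile levels, and random data below are assumptions, and compilation/training is elided):

import numpy as np
import bayesflow as bf

adapter = bf.ContinuousPointApproximator.build_adapter(
    inference_variables=["parameters"],
    inference_conditions=["observables"],
)

approximator = bf.ContinuousPointApproximator(
    adapter=adapter,
    inference_network=bf.networks.QuantileRegressor(quantile_levels=[0.05, 0.5, 0.95]),
)

# ... compile and fit on simulated data, as with ContinuousApproximator ...

# Estimation is a single forward pass instead of iterative sampling; the
# result holds the predicted quantiles for each inference variable.
estimates = approximator.estimate(conditions={"observables": np.random.normal(size=(32, 4))})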

bayesflow/networks/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,10 +1,12 @@
 from .cif import CIF
 from .consistency_models import ConsistencyModel, ContinuousConsistencyModel
 from .coupling_flow import CouplingFlow
+from .regressors import QuantileRegressor
 from .deep_set import DeepSet
 from .flow_matching import FlowMatching
 from .free_form_flow import FreeFormFlow
 from .inference_network import InferenceNetwork
+from .point_inference_network import PointInferenceNetwork
 from .mlp import MLP
 from .lstnet import LSTNet
 from .summary_network import SummaryNetwork
bayesflow/networks/point_inference_network.py

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+import keras
+
+from bayesflow.types import Shape, Tensor
+
+
+class PointInferenceNetwork(keras.Layer):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
+        pass
+
+    def call(
+        self,
+        xz: Tensor,
+        conditions: Tensor = None,
+        training: bool = False,
+        **kwargs,
+    ) -> Tensor | tuple[Tensor, Tensor]:
+        return self._forward(xz, conditions=conditions, training=training, **kwargs)
+
+    def _forward(
+        self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs
+    ) -> Tensor | tuple[Tensor, Tensor]:
+        raise NotImplementedError
+
+    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
+        if not self.built:
+            xz_shape = keras.ops.shape(x)
+            conditions_shape = None if conditions is None else keras.ops.shape(conditions)
+            self.build(xz_shape, conditions_shape=conditions_shape)
+
+        metrics = {}
+
+        if stage != "training" and any(self.metrics):
+            # compute sample-based metrics
+            # samples = self.sample((keras.ops.shape(x)[0],), conditions=conditions)
+            #
+            # for metric in self.metrics:
+            #     metrics[metric.name] = metric(samples, x)
+            pass
+            # TODO: instead compute estimate-based metrics
+
+        return metrics
+
+    def estimate(self, conditions: Tensor = None) -> Tensor:
+        return self._forward(None, conditions)
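
To make the subclass contract concrete, here is a hypothetical minimal implementation (an illustration, not part of this commit): a concrete network only needs build() and _forward(), where _forward maps conditions to the point estimate. Note that estimate() calls _forward with x=None, so the forward pass must not depend on x.

import keras
from bayesflow.types import Shape, Tensor

class MeanRegressor(PointInferenceNetwork):
    """Hypothetical example: predicts the conditional mean of each variable."""

    def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
        # one output unit per inference variable
        self.head = keras.layers.Dense(units=xz_shape[-1])
        self.head.build(input_shape=conditions_shape)

    def _forward(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
        # x may be None when called via estimate(); only the conditions are used
        return self.head(conditions)

    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
        # the squared-error minimizer is the conditional mean
        loss = keras.ops.mean(keras.ops.square(self(x, conditions) - x))
        return base_metrics | {"loss": loss}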
bayesflow/networks/regressors/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .quantile_regressor import QuantileRegressor
bayesflow/networks/regressors/quantile_regressor.py

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+from collections.abc import Sequence
+
+import keras
+from keras.saving import register_keras_serializable as serializable
+
+from bayesflow.types import Tensor
+from bayesflow.utils import find_network, keras_kwargs
+
+from ..point_inference_network import PointInferenceNetwork
+
+
+@serializable(package="networks.regressors")
+class QuantileRegressor(PointInferenceNetwork):
+    def __init__(
+        self,
+        subnet: str | type = "mlp",
+        quantile_levels: Sequence[float] = None,
+        **kwargs,
+    ):
+        super().__init__(**keras_kwargs(kwargs))
+
+        if quantile_levels is not None:
+            self.quantile_levels = quantile_levels
+        else:
+            self.quantile_levels = [0.1, 0.9]
+        self.quantile_levels = keras.ops.convert_to_tensor(self.quantile_levels)
+        self.num_quantiles = len(self.quantile_levels)  # should we have this shorthand?
+        # TODO: should we initialize self.num_variables here already? The actual value is assigned in build()
+
+        self.body = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
+        self.head = keras.layers.Dense(
+            units=None, bias_initializer="zeros", kernel_initializer="zeros"
+        )  # TODO: why initialize at zero? (taken from consistency_model.py)
+
+    # noinspection PyMethodOverriding
+    def build(
+        self, xz_shape, conditions_shape=None
+    ):  # TODO: it seems like conditions_shape should definitely be supplied; change to a positional argument?
+        super().build(xz_shape)
+
+        self.num_variables = xz_shape[-1]
+        input_shape = conditions_shape
+        self.body.build(input_shape=input_shape)
+
+        input_shape = self.body.compute_output_shape(input_shape)
+        self.head.units = self.num_quantiles * self.num_variables
+        self.head.build(input_shape=input_shape)
+
+    def _forward(
+        self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs
+    ) -> Tensor | tuple[Tensor, Tensor]:
+        head_input = self.body(conditions)
+        pred_quantiles = self.head(head_input)  # (batch_shape, num_quantiles * num_variables)
+        pred_quantiles = keras.ops.reshape(pred_quantiles, (-1, self.num_quantiles, self.num_variables))
+        # (batch_shape, num_quantiles, num_variables)
+
+        return pred_quantiles
+
+    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
+        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
+
+        true_value = x
+        # TODO: kept as before, but why do we not set training=(stage == "training") in self.call()?
+        pred_quantiles = self(x, conditions)
+        pointwise_difference = pred_quantiles - true_value[:, None, :]
+
+        loss = pointwise_difference * (
+            keras.ops.cast(pointwise_difference > 0, float) - self.quantile_levels[None, :, None]
+        )
+        loss = keras.ops.mean(loss)
+
+        return base_metrics | {"loss": loss}
