Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions bayesflow/approximators/continuous_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def build_adapter(
inference_variables: Sequence[str],
inference_conditions: Sequence[str] = None,
summary_variables: Sequence[str] = None,
sample_weights: Sequence[str] = None,
) -> Adapter:
adapter = Adapter.create_default(inference_variables)

Expand All @@ -50,7 +51,12 @@ def build_adapter(
if summary_variables is not None:
adapter = adapter.as_set(summary_variables).concatenate(summary_variables, into="summary_variables")

adapter = adapter.keep(["inference_variables", "inference_conditions", "summary_variables"]).standardize()
if sample_weights is not None: # we could provide automatic multiplication of different sample weights
adapter = adapter.concatenate(sample_weights, into="sample_weights")

adapter = adapter.keep(
["inference_variables", "inference_conditions", "summary_variables", "sample_weights"]
).standardize(exclude="sample_weights")

return adapter

Expand All @@ -77,6 +83,7 @@ def compute_metrics(
inference_variables: Tensor,
inference_conditions: Tensor = None,
summary_variables: Tensor = None,
sample_weights: Tensor = None,
stage: str = "training",
) -> dict[str, Tensor]:
if self.summary_network is None:
Expand All @@ -98,7 +105,7 @@ def compute_metrics(
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)

inference_metrics = self.inference_network.compute_metrics(
inference_variables, conditions=inference_conditions, stage=stage
inference_variables, conditions=inference_conditions, sample_weights=sample_weights, stage=stage
)

loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
Expand Down
8 changes: 5 additions & 3 deletions bayesflow/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,12 @@ def _inverse(

return x

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
def compute_metrics(
self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
) -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weights=sample_weights, stage=stage)

z, log_density = self(x, conditions=conditions, inverse=False, density=True)
loss = -keras.ops.mean(log_density)
loss = self.aggregate(-log_density, sample_weights)

return base_metrics | {"loss": loss}
10 changes: 7 additions & 3 deletions bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ def deltas(t, xz):
return x

def compute_metrics(
self, x: Tensor | Sequence[Tensor, ...], conditions: Tensor = None, stage: str = "training"
self,
x: Tensor | Sequence[Tensor, ...],
conditions: Tensor = None,
sample_weights: Tensor = None,
stage: str = "training",
) -> dict[str, Tensor]:
if isinstance(x, Sequence):
# already pre-configured
Expand All @@ -208,11 +212,11 @@ def compute_metrics(
x = t * x1 + (1 - t) * x0
target_velocity = x1 - x0

base_metrics = super().compute_metrics(x1, conditions, stage)
base_metrics = super().compute_metrics(x1, conditions, sample_weights, stage)

predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")

loss = self.loss_fn(target_velocity, predicted_velocity)
loss = keras.ops.mean(loss)
loss = self.aggregate(loss, sample_weights)

return base_metrics | {"loss": loss}
10 changes: 7 additions & 3 deletions bayesflow/networks/free_form_flow/free_form_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,10 @@ def _sample_v(self, x):
raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
return v

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
def compute_metrics(
self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
) -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weights=sample_weights, stage=stage)
# sample random vector
v = self._sample_v(x)

Expand All @@ -204,6 +206,8 @@ def decode(z):
nll = -self.base_distribution.log_prob(z)
maximum_likelihood_loss = nll - surrogate
reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
loss = ops.mean(maximum_likelihood_loss + self.beta * reconstruction_loss)

losses = maximum_likelihood_loss + self.beta * reconstruction_loss
loss = self.aggregate(losses, sample_weights)

return base_metrics | {"loss": loss}
11 changes: 10 additions & 1 deletion bayesflow/networks/inference_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tens
_, log_density = self(samples, conditions=conditions, inverse=False, density=True, **kwargs)
return log_density

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
def compute_metrics(
self, x: Tensor, conditions: Tensor = None, sample_weights: Tensor = None, stage: str = "training"
) -> dict[str, Tensor]:
if not self.built:
xz_shape = keras.ops.shape(x)
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
Expand All @@ -62,3 +64,10 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
metrics[metric.name] = metric(samples, x)

return metrics

def aggregate(self, losses: Tensor, weights: Tensor = None):
    """Reduce per-sample losses to a scalar loss.

    When `weights` is given, each per-sample loss is scaled by its
    sample weight before averaging; otherwise the losses are averaged
    as-is. Note the reduction is a plain mean over the scaled values
    (it does not renormalize by the weight sum).
    """
    scaled = losses if weights is None else losses * weights
    return keras.ops.mean(scaled)