@@ -53,6 +53,7 @@ def build_adapter(
5353 inference_variables : Sequence [str ],
5454 inference_conditions : Sequence [str ] = None ,
5555 summary_variables : Sequence [str ] = None ,
56+ sample_weight : Sequence [str ] = None ,
5657 ) -> Adapter :
5758 """Create an :py:class:`~bayesflow.adapters.Adapter` suited for the approximator.
5859
@@ -64,6 +65,8 @@ def build_adapter(
6465 Names of the inference conditions in the data
6566 summary_variables : Sequence of str, optional
6667 Names of the summary variables in the data
68+ sample_weight : str, optional
69+ Name of the sample weights
6770 """
6871 adapter = Adapter ()
6972 adapter .to_array ()
@@ -77,8 +80,11 @@ def build_adapter(
7780 adapter .as_set (summary_variables )
7881 adapter .concatenate (summary_variables , into = "summary_variables" )
7982
80- adapter .keep (["inference_variables" , "inference_conditions" , "summary_variables" ])
81- adapter .standardize ()
83+ if sample_weight is not None :
84+ adapter = adapter .rename (sample_weight , "sample_weight" )
85+
86+ adapter .keep (["inference_variables" , "inference_conditions" , "summary_variables" , "sample_weight" ])
87+ adapter .standardize (exclude = "sample_weight" )
8288
8389 return adapter
8490
@@ -105,6 +111,7 @@ def compute_metrics(
105111 inference_variables : Tensor ,
106112 inference_conditions : Tensor = None ,
107113 summary_variables : Tensor = None ,
114+ sample_weight : Tensor = None ,
108115 stage : str = "training" ,
109116 ) -> dict [str , Tensor ]:
110117 if self .summary_network is None :
@@ -128,7 +135,7 @@ def compute_metrics(
128135 # Force a conversion to Tensor
129136 inference_variables = keras .tree .map_structure (keras .ops .convert_to_tensor , inference_variables )
130137 inference_metrics = self .inference_network .compute_metrics (
131- inference_variables , conditions = inference_conditions , stage = stage
138+ inference_variables , conditions = inference_conditions , sample_weight = sample_weight , stage = stage
132139 )
133140
134141 loss = inference_metrics .get ("loss" , keras .ops .zeros (())) + summary_metrics .get ("loss" , keras .ops .zeros (()))
0 commit comments