@@ -41,6 +41,7 @@ def build_adapter(
4141 inference_variables : Sequence [str ],
4242 inference_conditions : Sequence [str ] = None ,
4343 summary_variables : Sequence [str ] = None ,
44+ sample_weights : Sequence [str ] = None ,
4445 ) -> Adapter :
4546 adapter = Adapter .create_default (inference_variables )
4647
@@ -50,7 +51,12 @@ def build_adapter(
5051 if summary_variables is not None :
5152 adapter = adapter .as_set (summary_variables ).concatenate (summary_variables , into = "summary_variables" )
5253
53- adapter = adapter .keep (["inference_variables" , "inference_conditions" , "summary_variables" ]).standardize ()
54+ if sample_weights is not None : # we could provide automatic multiplication of different sample weights
55+ adapter = adapter .concatenate (sample_weights , into = "sample_weights" )
56+
57+ adapter = adapter .keep (
58+ ["inference_variables" , "inference_conditions" , "summary_variables" , "sample_weights" ]
59+ ).standardize (exclude = "sample_weights" )
5460
5561 return adapter
5662
@@ -77,6 +83,7 @@ def compute_metrics(
7783 inference_variables : Tensor ,
7884 inference_conditions : Tensor = None ,
7985 summary_variables : Tensor = None ,
86+ sample_weights : Tensor = None ,
8087 stage : str = "training" ,
8188 ) -> dict [str , Tensor ]:
8289 if self .summary_network is None :
@@ -98,7 +105,7 @@ def compute_metrics(
98105 inference_conditions = keras .ops .concatenate ([inference_conditions , summary_outputs ], axis = - 1 )
99106
100107 inference_metrics = self .inference_network .compute_metrics (
101- inference_variables , conditions = inference_conditions , stage = stage
108+ inference_variables , conditions = inference_conditions , sample_weights = sample_weights , stage = stage
102109 )
103110
104111 loss = inference_metrics .get ("loss" , keras .ops .zeros (())) + summary_metrics .get ("loss" , keras .ops .zeros (()))
0 commit comments