Skip to content

Commit 1bb85cf

Browse files
authored
Changed the specification of MeanMetricWrapper (#21569)
1 parent 21cd5c0 commit 1bb85cf

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

keras/src/metrics/reduction_metrics.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,9 @@ def __init__(self, fn, name=None, dtype=None, **kwargs):
201201
def update_state(self, y_true, y_pred, sample_weight=None):
202202
mask = backend.get_keras_mask(y_pred)
203203
values = self._fn(y_true, y_pred, **self._fn_kwargs)
204-
if sample_weight is not None and mask is not None:
205-
sample_weight = losses.loss.apply_mask(
206-
sample_weight, mask, dtype=self.dtype, reduction="sum"
207-
)
204+
sample_weight = losses.loss.apply_mask(
205+
sample_weight, mask, dtype=self.dtype, reduction="sum"
206+
)
208207
return super().update_state(values, sample_weight=sample_weight)
209208

210209
def get_config(self):

0 commit comments

Comments
 (0)