
Commit 2229a68

Ensures compiled metrics resolve output names correctly (#21694)
* Ensures compiled metrics resolve output names correctly
* Improves readability
1 parent b062368 commit 2229a68
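
Context for the change: with the functional API a model can return its outputs as a dict whose keys differ from the internal op names recorded in output_names (the new test below uses dict keys "a"/"b" against op names "dense"/"dense_1"). Before this commit, CompileMetrics flattened such dicts using output_names, so per-output metrics could end up paired with the wrong outputs or with nothing at all. A minimal sketch of that scenario, assuming only the public Keras API (the model structure and names here are illustrative, not taken from the commit):

import numpy as np
import keras

# Hypothetical model: dict outputs keyed "a"/"b", while the underlying Dense
# layers receive auto-generated op names such as "dense" / "dense_1".
inputs = keras.Input(shape=(4,))
outputs = {
    "a": keras.layers.Dense(2)(inputs),
    "b": keras.layers.Dense(2)(inputs),
}
model = keras.Model(inputs, outputs)

# Losses and metrics are declared against the dict keys, not the op names.
model.compile(
    optimizer="sgd",
    loss={"a": "mse", "b": "mse"},
    metrics={"a": "mse", "b": "mse"},
)

x = np.random.random((8, 4)).astype("float32")
y = {"a": np.zeros((8, 2), "float32"), "b": np.zeros((8, 2), "float32")}
model.fit(x, y, epochs=1, verbose=0)
# With this fix, the per-output metric results are reported under the dict
# keys (e.g. "a_mse" / "b_mse") instead of being matched against op names.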

File tree: 2 files changed, +65 −5 lines


keras/src/trainers/compile_utils.py

Lines changed: 14 additions & 5 deletions
@@ -148,6 +148,7 @@ def __init__(
         self.built = False
         self.name = "compile_metrics"
         self.output_names = output_names
+        self._resolved_output_names = None

     @property
     def metrics(self):
@@ -175,10 +176,16 @@ def variables(self):

     def build(self, y_true, y_pred):
         num_outputs = 1  # default
-        if self.output_names:
+        # Resolve output names. If y_pred is a dict, prefer its keys.
+        if isinstance(y_pred, dict):
+            keys = sorted(list(y_pred.keys()))
+            if self.output_names and set(self.output_names) == set(keys):
+                # If there is a perfect match, use the user-provided order.
+                output_names = self.output_names
+            else:
+                output_names = keys
+        elif self.output_names:
             output_names = self.output_names
-        elif isinstance(y_pred, dict):
-            output_names = sorted(list(y_pred.keys()))
         elif isinstance(y_pred, (list, tuple)):
             num_outputs = len(y_pred)
             if all(hasattr(x, "_keras_history") for x in y_pred):
@@ -187,6 +194,7 @@ def build(self, y_true, y_pred):
                 output_names = None
         else:
             output_names = None
+        self._resolved_output_names = output_names
         if output_names:
             num_outputs = len(output_names)
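
The three hunks above change how CompileMetrics.build() picks its output names: when y_pred is a dict, its keys now take precedence over the compile-time output_names, except that a user-provided order is kept when the two name sets match exactly, and the winner is cached in _resolved_output_names. A minimal standalone sketch of that precedence (the helper name resolve_output_names is illustrative only, not part of the Keras API):

def resolve_output_names(output_names, y_pred):
    """Sketch of the precedence CompileMetrics.build() now applies."""
    if isinstance(y_pred, dict):
        keys = sorted(y_pred.keys())
        if output_names and set(output_names) == set(keys):
            # Exact set match: keep the user-provided order.
            return output_names
        # Otherwise the dict keys win, in sorted order.
        return keys
    if output_names:
        return output_names
    # (The list/tuple and fallback branches of the real method are omitted.)
    return None

# Mismatched names: dict keys take precedence.
assert resolve_output_names(["dense", "dense_1"], {"b": 0, "a": 0}) == ["a", "b"]
# Matching sets: the user-provided order is preserved.
assert resolve_output_names(["b", "a"], {"a": 0, "b": 0}) == ["b", "a"]
# No dict output: compile-time names are used as before.
assert resolve_output_names(["out"], [0, 0]) == ["out"]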

@@ -316,9 +324,10 @@ def _build_metrics_set(
         return flat_metrics

     def _flatten_y(self, y):
-        if isinstance(y, dict) and self.output_names:
+        names = self._resolved_output_names
+        if isinstance(y, dict) and names:
             result = []
-            for name in self.output_names:
+            for name in names:
                 if name in y:
                     result.append(y[name])
             return result
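
With the resolved names cached, _flatten_y (last hunk above) now orders a dict (e.g. of targets) by _resolved_output_names rather than by the original output_names, so the flattened list lines up with the per-output metrics even when the dict keys differ from the internal op names. A small illustration of that ordering, reusing the names from the test below:

# Names resolved at build time from y_pred's dict keys (see the hunks above).
names = ["a", "b"]

# A targets dict supplied by the user, deliberately in a different order.
y = {"b": "target_b", "a": "target_a"}

# _flatten_y walks the resolved names, so the flat list lines up with the
# per-output metrics regardless of the dict's insertion order or op names.
flat = [y[name] for name in names if name in y]
assert flat == ["target_a", "target_b"]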

keras/src/trainers/compile_utils_test.py

Lines changed: 51 additions & 0 deletions
@@ -235,6 +235,57 @@ def my_custom_metric(y_true, y_pred):
         self.assertEqual(len(result), 1)
         self.assertTrue("my_custom_metric" in result)

+    def test_dict_outputs_ignore_mismatched_output_names(self):
+        """Tests that when output_names does not match dict keys, the correct
+        keys are used."""
+
+        # output_names represent internal op names that do not match dict keys.
+        compile_metrics = CompileMetrics(
+            metrics={
+                "a": metrics_module.MeanSquaredError(),
+                "b": metrics_module.MeanSquaredError(),
+            },
+            weighted_metrics=None,
+            output_names=["dense", "dense_1"],
+        )
+
+        # Symbolic build with dict outputs keyed by user-facing names.
+        y_true = {
+            "a": backend.KerasTensor((3, 2)),
+            "b": backend.KerasTensor((3, 2)),
+        }
+        y_pred = {
+            "a": backend.KerasTensor((3, 2)),
+            "b": backend.KerasTensor((3, 2)),
+        }
+
+        # The build method should correctly map metrics for outputs 'a' and 'b',
+        # even when the op names do not match.
+        compile_metrics.build(y_true, y_pred)
+
+        # Make the two outputs produce different MSEs to verify mapping.
+        y_true = {
+            "a": np.zeros((3, 2), dtype="float32"),
+            "b": np.zeros((3, 2), dtype="float32"),
+        }
+        y_pred = {
+            # MSE(a) = 0.0
+            "a": np.zeros((3, 2), dtype="float32"),
+            # MSE(b) = 1.0
+            "b": np.ones((3, 2), dtype="float32"),
+        }
+        compile_metrics.update_state(y_true, y_pred)
+
+        result = compile_metrics.result()
+        self.assertIsInstance(result, dict)
+
+        # Should expose metrics under the dict keys ('a', 'b'),
+        # and not the internal names.
+        self.assertIn("a_mean_squared_error", result)
+        self.assertIn("b_mean_squared_error", result)
+        self.assertAllClose(result["a_mean_squared_error"], 0.0)
+        self.assertAllClose(result["b_mean_squared_error"], 1.0, atol=1e-6)
+

 class TestCompileLoss(testing.TestCase):
     def test_single_output_case(self):
