@@ -235,6 +235,57 @@ def my_custom_metric(y_true, y_pred):
235
235
self .assertEqual (len (result ), 1 )
236
236
self .assertTrue ("my_custom_metric" in result )
237
237
238
+ def test_dict_outputs_ignore_mismatched_output_names (self ):
239
+ """Tests that when output_names does not match dict keys, the correct
240
+ keys are used."""
241
+
242
+ # output_names represent internal op names that do not match dict keys.
243
+ compile_metrics = CompileMetrics (
244
+ metrics = {
245
+ "a" : metrics_module .MeanSquaredError (),
246
+ "b" : metrics_module .MeanSquaredError (),
247
+ },
248
+ weighted_metrics = None ,
249
+ output_names = ["dense" , "dense_1" ],
250
+ )
251
+
252
+ # Symbolic build with dict outputs keyed by user-facing names.
253
+ y_true = {
254
+ "a" : backend .KerasTensor ((3 , 2 )),
255
+ "b" : backend .KerasTensor ((3 , 2 )),
256
+ }
257
+ y_pred = {
258
+ "a" : backend .KerasTensor ((3 , 2 )),
259
+ "b" : backend .KerasTensor ((3 , 2 )),
260
+ }
261
+
262
+ # The build method should correctly map metrics for outputs 'a' and 'b',
263
+ # even when the op names do not match.
264
+ compile_metrics .build (y_true , y_pred )
265
+
266
+ # Make the two outputs produce different MSEs to verify mapping.
267
+ y_true = {
268
+ "a" : np .zeros ((3 , 2 ), dtype = "float32" ),
269
+ "b" : np .zeros ((3 , 2 ), dtype = "float32" ),
270
+ }
271
+ y_pred = {
272
+ # MSE(a) = 0.0
273
+ "a" : np .zeros ((3 , 2 ), dtype = "float32" ),
274
+ # MSE(b) = 1.0
275
+ "b" : np .ones ((3 , 2 ), dtype = "float32" ),
276
+ }
277
+ compile_metrics .update_state (y_true , y_pred )
278
+
279
+ result = compile_metrics .result ()
280
+ self .assertIsInstance (result , dict )
281
+
282
+ # Should expose metrics under the dict keys ('a', 'b'),
283
+ # and not the internal names.
284
+ self .assertIn ("a_mean_squared_error" , result )
285
+ self .assertIn ("b_mean_squared_error" , result )
286
+ self .assertAllClose (result ["a_mean_squared_error" ], 0.0 )
287
+ self .assertAllClose (result ["b_mean_squared_error" ], 1.0 , atol = 1e-6 )
288
+
238
289
239
290
class TestCompileLoss (testing .TestCase ):
240
291
def test_single_output_case (self ):
0 commit comments