@@ -288,9 +288,12 @@ class ResultMetricCollection(dict):
288288 with the same metadata.
289289 """
290290
291- def __init__ (self , * args , metadata : Optional [ _Metadata ] = None ) -> None :
291+ def __init__ (self , * args ) -> None :
292292 super ().__init__ (* args )
293- self .meta = metadata
293+
294+ @property
295+ def meta (self ) -> _Metadata :
296+ return list (self .values ())[0 ].meta
294297
295298 def __getstate__ (self , drop_value : bool = False ) -> dict :
296299 def getstate (item : ResultMetric ) -> dict :
@@ -312,9 +315,6 @@ def setstate(item: dict) -> Union[Dict[str, ResultMetric], ResultMetric, Any]:
312315 items = setstate (state ["items" ])
313316 self .update (items )
314317
315- any_result_metric = next (iter (items .values ()))
316- self .meta = any_result_metric .meta
317-
318318 @classmethod
319319 def _reconstruct (cls , state : dict , sync_fn : Optional [Callable ] = None ) -> "ResultMetricCollection" :
320320 rmc = cls ()
@@ -479,7 +479,7 @@ def fn(v: _METRIC) -> ResultMetric:
479479
480480 value = apply_to_collection (value , (torch .Tensor , Metric ), fn )
481481 if isinstance (value , dict ):
482- value = ResultMetricCollection (value , metadata = meta )
482+ value = ResultMetricCollection (value )
483483 self [key ] = value
484484
485485 def update_metrics (self , key : str , value : _METRIC_COLLECTION ) -> None :
@@ -590,7 +590,6 @@ def extract_batch_size(self, batch: Any) -> None:
590590
591591 def to (self , * args , ** kwargs ) -> "ResultCollection" :
592592 """Move all data to the given device."""
593-
594593 self .update (apply_to_collection (dict (self ), (torch .Tensor , Metric ), move_data_to_device , * args , ** kwargs ))
595594
596595 if self .minimize is not None :
0 commit comments