Skip to content

Commit 80fa349

Browse files
tchatonlexierule
authored andcommitted
[bugfix] Resolve lost reference to meta object in ResultMetricCollection (#8932)
1 parent 403aa4d commit 80fa349

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,12 @@ class ResultMetricCollection(dict):
288288
with the same metadata.
289289
"""
290290

291-
def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None:
291+
def __init__(self, *args) -> None:
292292
super().__init__(*args)
293-
self.meta = metadata
293+
294+
@property
295+
def meta(self) -> _Metadata:
296+
return list(self.values())[0].meta
294297

295298
def __getstate__(self, drop_value: bool = False) -> dict:
296299
def getstate(item: ResultMetric) -> dict:
@@ -312,9 +315,6 @@ def setstate(item: dict) -> Union[Dict[str, ResultMetric], ResultMetric, Any]:
312315
items = setstate(state["items"])
313316
self.update(items)
314317

315-
any_result_metric = next(iter(items.values()))
316-
self.meta = any_result_metric.meta
317-
318318
@classmethod
319319
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetricCollection":
320320
rmc = cls()
@@ -479,7 +479,7 @@ def fn(v: _METRIC) -> ResultMetric:
479479

480480
value = apply_to_collection(value, (torch.Tensor, Metric), fn)
481481
if isinstance(value, dict):
482-
value = ResultMetricCollection(value, metadata=meta)
482+
value = ResultMetricCollection(value)
483483
self[key] = value
484484

485485
def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
@@ -590,7 +590,6 @@ def extract_batch_size(self, batch: Any) -> None:
590590

591591
def to(self, *args, **kwargs) -> "ResultCollection":
592592
"""Move all data to the given device."""
593-
594593
self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))
595594

596595
if self.minimize is not None:

tests/trainer/logging_/test_eval_loop_logging.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,3 +590,29 @@ def get_metrics_at_idx(idx):
590590
"test_loss",
591591
}
592592
assert set(results[0]) == {"test_loss"}
593+
594+
595+
def test_logging_dict_on_validation_step(tmpdir):
596+
class TestModel(BoringModel):
597+
def validation_step(self, batch, batch_idx):
598+
loss = super().validation_step(batch, batch_idx)
599+
loss = loss["x"]
600+
metrics = {
601+
"loss": loss,
602+
"loss_1": loss,
603+
}
604+
self.log("val_metrics", metrics)
605+
606+
validation_epoch_end = None
607+
608+
model = TestModel()
609+
610+
trainer = Trainer(
611+
default_root_dir=tmpdir,
612+
limit_train_batches=2,
613+
limit_val_batches=2,
614+
max_epochs=2,
615+
progress_bar_refresh_rate=1,
616+
)
617+
618+
trainer.fit(model)

0 commit comments

Comments
 (0)