
Confusing recommendation to use sync_dist=True even with TorchMetrics #20153

@srprca

Bug description

Hello!

When I train and validate a model in a multi-GPU setting (HPC, an sbatch job that requests multiple GPUs on a single node), I use self.log(..., sync_dist=True) when logging PyTorch losses, and I don't specify any value for sync_dist when logging metrics from the TorchMetrics library. However, I still get warnings like

...
It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
...
It is recommended to use `self.log('val_bg_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.

These specific messages correspond to logging tmc.MulticlassRecall(len(self.task.class_names), average="macro", ignore_index=self.metric_ignore_index) and the individual per-class components of tmc.MulticlassRecall(len(self.task.class_names), average="none", ignore_index=self.metric_ignore_index).
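
In short, the setup looks roughly like this (a condensed sketch with illustrative names and batch layout, not my exact code):

import torch.nn.functional as F

def validation_step(self, batch, batch_idx):
    logits, targets = self(batch["x"]), batch["y"]  # hypothetical batch layout
    loss = F.cross_entropy(logits, targets)
    # Plain loss tensor: here I follow the recommendation and sync explicitly.
    self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
    # TorchMetrics objects are only updated here; their values are logged
    # manually at epoch end without sync_dist (see the full listing below).
    self.val_metric_funs.update(logits, targets)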

The full code listing for the metric object definitions and logging is provided in the "How to reproduce the bug" section below.

As I understand from a note here and from the discussion here, one typically doesn't need to set sync_dist explicitly when using TorchMetrics.
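
For example, the pattern I understand the docs to recommend looks like this (a minimal sketch, assuming a scalar-valued metric and a hypothetical batch layout):

# In __init__ (num_classes is illustrative):
self.val_recall = tmc.MulticlassRecall(num_classes, average="macro")

def validation_step(self, batch, batch_idx):
    logits, targets = self(batch["x"]), batch["y"]  # hypothetical batch layout
    self.val_recall.update(logits, targets)
    # The Metric object itself is passed to self.log, not a float; Lightning
    # and TorchMetrics then handle compute, reset, and cross-device sync.
    self.log("val_mean_recall", self.val_recall, on_step=False, on_epoch=True)

Since self.log only accepts scalar values, this pattern doesn't cover the per-class (average="none") results or the confusion matrix, which is why those values in my code are extracted with .item() and logged individually.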

I wonder whether I still need to enable sync_dist=True as advised by the warnings due to some special case I am not aware of, or whether I should follow the docs and keep the code as is. In either case, this is probably a bug, either in the documentation or in the warning code.

Thank you!

What version are you seeing the problem on?

2.3.0

How to reproduce the bug

# Assumed imports for the snippets below:
import torchmetrics as tm
import torchmetrics.classification as tmc

# Metric objects, defined once (e.g. in __init__):
self.val_metric_funs = tm.MetricCollection(
    {
        "cm_normalize_all": tmc.MulticlassConfusionMatrix(
            len(self.task.class_names),
            ignore_index=self.metric_ignore_index,
            normalize="all",
        ),
        "recall_average_macro": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "recall_average_none": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_macro": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_none": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_macro": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_none": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
    }
)
# Logging the computed values; `metrics` holds the collection's computed
# results and `sanity_check` is presumably True during the sanity-check run.
if not sanity_check:
    for metric_name, metric in metrics.items():
        metric_fun = self.val_metric_funs[metric_name]
        metric_name_ = metric_name.split("_")[0]
        if isinstance(metric_fun, tmc.MulticlassConfusionMatrix):
            # Log every cell of the confusion matrix as its own scalar.
            for true_class_num in range(metric.shape[0]):
                true_class_name = self.task.class_names[true_class_num]
                for pred_class_num in range(metric.shape[1]):
                    pred_class_name = self.task.class_names[pred_class_num]
                    self.log(
                        f"val_true_{true_class_name}_pred_{pred_class_name}_cm",
                        metric[true_class_num, pred_class_num].item(),
                        on_step=False,
                        on_epoch=True,
                        logger=True,
                    )
        elif isinstance(
            metric_fun,
            (
                tmc.MulticlassRecall,
                tmc.MulticlassPrecision,
                tmc.MulticlassF1Score,
            ),
        ):
            if metric_fun.average == "macro":
                self.log(
                    f"val_mean_{metric_name_}",
                    metric.item(),
                    on_step=False,
                    on_epoch=True,
                    logger=True,
                )
            elif metric_fun.average == "none":
                # Log one scalar per class.
                for class_num, metric_ in enumerate(metric):
                    class_name = self.task.class_names[class_num]
                    self.log(
                        f"val_{class_name}_{metric_name_}",
                        metric_.item(),
                        on_step=False,
                        on_epoch=True,
                        logger=True,
                    )
            else:
                raise NotImplementedError(
                    f"Code for logging metric {metric_name} is not implemented"
                )
        else:
            raise NotImplementedError(
                f"Code for logging metric {metric_name} is not implemented"
            )
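
For context, the logging code above runs at validation epoch end, roughly like the following (a condensed sketch; variable names match the snippet above, the rest of the hook is assumed):

def on_validation_epoch_end(self):
    # compute() synchronizes the metric states across processes by default
    # before returning the values that get logged above.
    metrics = self.val_metric_funs.compute()
    sanity_check = self.trainer.sanity_checking
    # ... logging loop from above ...
    self.val_metric_funs.reset()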

Error messages and logs

...
It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
...
It is recommended to use `self.log('val_bg_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.

Environment

- PyTorch Lightning Version: 2.3.0
- PyTorch Version: 2.3.1
- Python version: 3.11.9
- OS: Linux
- CUDA/cuDNN version: 11.8
- How you installed Lightning: conda-forge

More info

No response

cc @carmocca


Labels

bug (Something isn't working) · help wanted (Open to be worked on) · logging (Related to the `LoggerConnector` and `log()`) · ver: 2.2.x
