Bug description
Hello!
When I train and validate a model in a multi-GPU setting (HPC, an sbatch job that requests multiple GPUs on a single node), I use `self.log(..., sync_dist=True)` when logging PyTorch losses, and don't specify any value for `sync_dist` when logging metrics from the TorchMetrics library. However, I still get warnings like
```
...
It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
...
It is recommended to use `self.log('val_bg_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
```
These specific messages correspond to logging `tmc.MulticlassRecall(len(self.task.class_names), average="macro", ignore_index=self.metric_ignore_index)` and the individual components of `tmc.MulticlassRecall(len(self.task.class_names), average="none", ignore_index=self.metric_ignore_index)`.
The full code listing for the metric object definitions and logging is provided in the "How to reproduce the bug" section.
As I understand from a note here, and from the discussion here, one doesn't typically need to use `sync_dist` explicitly when logging TorchMetrics metrics.
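For illustration, here is a minimal sketch of the pattern I understand those docs to describe (the `LitModel` class, the linear layer, and the metric/loss names below are made up for this example; only the logging calls matter):

```python
import torch
import torchmetrics.classification as tmc
from lightning.pytorch import LightningModule


class LitModel(LightningModule):  # hypothetical module, for illustration only
    def __init__(self, num_classes: int):
        super().__init__()
        self.layer = torch.nn.Linear(8, num_classes)
        self.val_recall = tmc.MulticlassRecall(num_classes, average="macro")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.layer(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        # Plain tensor: sync_dist=True is needed to average it across ranks.
        self.log("val_loss", loss, sync_dist=True)
        # TorchMetrics metric: the metric object itself is logged, and
        # TorchMetrics handles the cross-device accumulation internally,
        # so sync_dist is not passed here.
        self.val_recall.update(logits, y)
        self.log("val_mean_recall", self.val_recall, on_step=False, on_epoch=True)
```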
I wonder whether I still need to enable `sync_dist=True` as advised in the warnings because of some special case that I am not aware of, or whether I should follow the docs and keep the code as is. In either case, this is probably a bug, either in the documentation or in the warning code.
Thank you!
What version are you seeing the problem on?
2.3.0
How to reproduce the bug
```python
# Metric object definitions.
# Assumed aliases: import torchmetrics as tm; import torchmetrics.classification as tmc
self.val_metric_funs = tm.MetricCollection(
    {
        "cm_normalize_all": tmc.MulticlassConfusionMatrix(
            len(self.task.class_names),
            ignore_index=self.metric_ignore_index,
            normalize="all",
        ),
        "recall_average_macro": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "recall_average_none": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_macro": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_none": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_macro": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_none": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
    }
)
```

```python
# Logging of the computed metric values.
if not sanity_check:
    for metric_name, metric in metrics.items():
        metric_fun = self.val_metric_funs[metric_name]
        metric_name_ = metric_name.split("_")[0]
        if isinstance(metric_fun, tmc.MulticlassConfusionMatrix):
            # Log every cell of the confusion matrix as a separate scalar.
            for true_class_num in range(metric.shape[0]):
                true_class_name = self.task.class_names[true_class_num]
                for pred_class_num in range(metric.shape[1]):
                    pred_class_name = self.task.class_names[pred_class_num]
                    self.log(
                        f"val_true_{true_class_name}_pred_{pred_class_name}_cm",
                        metric[true_class_num, pred_class_num].item(),
                        on_step=False,
                        on_epoch=True,
                        logger=True,
                    )
        elif isinstance(
            metric_fun,
            (
                tmc.MulticlassRecall,
                tmc.MulticlassPrecision,
                tmc.MulticlassF1Score,
            ),
        ):
            if metric_fun.average == "macro":
                # Single macro-averaged value, e.g. "val_mean_recall".
                self.log(
                    f"val_mean_{metric_name_}",
                    metric.item(),
                    on_step=False,
                    on_epoch=True,
                    logger=True,
                )
            elif metric_fun.average == "none":
                # One value per class, e.g. "val_bg_recall".
                for class_num, metric_ in enumerate(metric):
                    class_name = self.task.class_names[class_num]
                    self.log(
                        f"val_{class_name}_{metric_name_}",
                        metric_.item(),
                        on_step=False,
                        on_epoch=True,
                        logger=True,
                    )
            else:
                raise NotImplementedError(
                    f"Code for logging metric {metric_name} is not implemented"
                )
        else:
            raise NotImplementedError(
                f"Code for logging metric {metric_name} is not implemented"
            )
```
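For context, the `metrics` and `sanity_check` names in the listing above come from the surrounding validation-epoch-end code, which is not shown. A hypothetical sketch of that wrapper (the hook placement and helper names here are assumptions, not part of the original listing):

```python
def on_validation_epoch_end(self):
    # Hypothetical wrapper around the logging loop above (names are assumptions):
    # `metrics` holds the values computed by the MetricCollection, and
    # `sanity_check` is used to skip logging during the sanity-check loop.
    sanity_check = self.trainer.sanity_checking
    metrics = self.val_metric_funs.compute()
    if not sanity_check:
        ...  # the per-metric logging loop shown above
    self.val_metric_funs.reset()
```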
Error messages and logs
```
...
It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
...
It is recommended to use `self.log('val_bg_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
```
Environment
Current environment
- PyTorch Lightning Version: 2.3.0
- PyTorch Version: 2.3.1
- Python version: 3.11.9
- OS: Linux
- CUDA/cuDNN version: 11.8
- How you installed Lightning: conda-forge
More info
No response
cc @carmocca