Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
34a1a3f
Keep only cocomap-related changes
sadra-barikbin May 17, 2023
24fe980
Some improvements
sadra-barikbin May 28, 2023
e2ac8ee
Update docs
sadra-barikbin May 28, 2023
e4683de
Merge branch 'master' into cocomap
sadra-barikbin May 28, 2023
7cf53e1
Fix a bug in docs
sadra-barikbin May 29, 2023
4aa9c5d
Fix a tiny bug related to allgather
sadra-barikbin Jun 15, 2023
950c388
Fix a few bugs
sadra-barikbin Jun 16, 2023
9f5f796
Redesign code:
sadra-barikbin Jun 16, 2023
ffb1ba4
Merge branch 'master' into cocomap
sadra-barikbin Jun 16, 2023
65cdd08
Remove all_gather with different shape
sadra-barikbin Jun 17, 2023
e54af52
Merge branch 'master' into cocomap
sadra-barikbin Jun 21, 2023
aac2e55
Add test for all_gather_with_different_shape func
sadra-barikbin Jun 21, 2023
4cf3972
Merge branch 'master' into cocomap
vfdev-5 Jun 21, 2023
6070e18
A few improvements
sadra-barikbin Aug 23, 2023
aa83e60
Merge remote-tracking branch 'upstream/cocomap' into cocomap
sadra-barikbin Aug 23, 2023
deebbde
Add an output transform
sadra-barikbin Aug 31, 2023
62ca5fb
Add a test for the output_transform
sadra-barikbin Aug 31, 2023
418fcf4
Remove 'flavor' because all DeciAI, Ultralytics, Detectron and pycoco…
sadra-barikbin Sep 1, 2023
5fea0cd
Merge branch 'master' into cocomap
sadra-barikbin Sep 1, 2023
79fa1e2
Revert Metric change and a few bug fix
sadra-barikbin Sep 10, 2023
26c96b8
A tiny improvement in local variable names
sadra-barikbin Sep 15, 2023
d18f793
Merge branch 'master' into cocomap
sadra-barikbin Sep 15, 2023
a361ca8
Add max_dep and area_range
sadra-barikbin Dec 4, 2023
ce48583
some improvements
sadra-barikbin Jun 28, 2024
cf02dc0
Improvement in code
sadra-barikbin Jul 11, 2024
1593dfb
Some improvements
sadra-barikbin Jul 12, 2024
e425e12
Fix a bug; Some improvements; Improve docs
sadra-barikbin Jul 16, 2024
bb15f0f
Fix metrics.rst
sadra-barikbin Jul 16, 2024
a184ba5
Merge branch 'master' into cocomap
sadra-barikbin Jul 16, 2024
6fcc97f
Remove @override which is for 3.12
sadra-barikbin Jul 16, 2024
120c755
Fix mypy issues
sadra-barikbin Jul 16, 2024
7c26d08
Fix two tests
sadra-barikbin Jul 16, 2024
c3c4a82
Fix a typo in tests
sadra-barikbin Jul 16, 2024
2405937
Fix dist tests
sadra-barikbin Jul 16, 2024
9b3100d
Merge branch 'master' into cocomap
sadra-barikbin Sep 3, 2024
356f618
Add common obj. det. metrics
sadra-barikbin Sep 3, 2024
bbfc4c7
Merge branch 'master' into cocomap
sadra-barikbin Sep 3, 2024
cb6a328
Change an annotation for the sake of M1 python3.8
sadra-barikbin Sep 3, 2024
248fe89
Use if check on torch.double usages for MPS backend
sadra-barikbin Sep 3, 2024
8bfb802
Fix a typo
sadra-barikbin Sep 4, 2024
4038c2b
Fix a bug related to tensors on same devices
sadra-barikbin Sep 4, 2024
4b6afdd
Fix a bug related to MPS and torch.double
sadra-barikbin Sep 4, 2024
d0e82b3
Fix a bug related to MPS
sadra-barikbin Sep 4, 2024
085e0df
Fix a bug related to MPS
sadra-barikbin Sep 4, 2024
3658f95
Fix a bug related to MPS
sadra-barikbin Sep 4, 2024
0444933
Resolve MPS's lack of cummax
sadra-barikbin Sep 4, 2024
c433718
Revert MPS fallback
sadra-barikbin Sep 4, 2024
dacf407
Apply comments
sadra-barikbin Sep 4, 2024
67454c3
Merge branch 'master' into cocomap
sadra-barikbin Sep 5, 2024
67e38c4
Revert unnecessary changes
sadra-barikbin Sep 5, 2024
978791b
Merge branch 'master' into cocomap
vfdev-5 Sep 9, 2024
7b43c69
Apply review comments
sadra-barikbin Sep 20, 2024
4d3fc57
Merge branch 'master' into cocomap
sadra-barikbin Sep 20, 2024
479d1b7
Merge branch 'master' into cocomap
sadra-barikbin Sep 29, 2024
954d130
Skip MPS on test_integration as well
sadra-barikbin Sep 29, 2024
d2978cf
Merge branch 'master' into cocomap
sadra-barikbin Oct 5, 2024
ade0737
Merge branch 'master' into cocomap
vfdev-5 Feb 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,13 @@ Complete list of metrics
Frequency
Loss
MeanAbsoluteError
MeanAveragePrecision
MeanPairwiseDistance
MeanSquaredError
metric.Metric
metrics_lambda.MetricsLambda
MultiLabelConfusionMatrix
ObjectDetectionMAP
precision.Precision
PSNR
recall.Recall
Expand Down
66 changes: 59 additions & 7 deletions ignite/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import itertools
import socket
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import torch

Expand Down Expand Up @@ -350,29 +351,80 @@ def all_reduce(
return _model.all_reduce(tensor, op, group=group)


def _all_gather_tensors_with_shapes(
    tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
) -> List[torch.Tensor]:
    """All-gather tensors whose shapes differ across the participating processes.

    Each process pads its ``tensor`` up to the element-wise maximum of ``shapes``, runs a
    regular ``all_gather`` on the now equally-shaped padded tensors, and finally slices each
    gathered tensor back down to its original shape.

    Args:
        tensor: local tensor to gather. Must have the same number of dimensions on every
            participating process.
        shapes: shapes of the tensors of all participating processes, indexed by rank.
        group: list of ranks or an already-created process group handle. If None, the
            default process group (all processes) is used.

    Returns:
        List of gathered tensors, one per entry of ``shapes``, each trimmed to its original
        shape. If the model is serial or the current process is not in ``group``, a list
        containing only the local ``tensor`` is returned.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    # Remember the explicit rank list (when given): after `new_group`, `group` holds a
    # backend-specific handle that does not support `rank in group` membership tests.
    ranks: Optional[List[int]] = None
    if isinstance(group, list) and all(isinstance(item, int) for item in group):
        ranks = group
        group = _model.new_group(group)

    if isinstance(_model, _SerialModel) or (ranks is not None and _model.get_rank() not in ranks):
        return [tensor]

    # Per-dimension maximum across ranks. The original used `amax(dim=1)`, which reduces
    # over the dimension axis instead of the rank axis and yields a padding vector of the
    # wrong length whenever the number of ranks differs from the tensor rank.
    max_shape = torch.tensor(shapes).amax(dim=0)
    padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
    # F.pad takes sizes as (last_dim_before, last_dim_after, ..., first_dim_before, ...),
    # hence the reversal; we only pad at the trailing end of every dimension.
    padded_tensor = torch.nn.functional.pad(
        tensor, tuple(itertools.chain.from_iterable((0, dim_size) for dim_size in reversed(padding_sizes)))
    )
    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
    # Gathered tensors are concatenated along dim 0: slice out each rank's chunk and trim
    # the padding off again using that rank's original shape. `shapes` contains exactly one
    # entry per participating process, so no further filtering is needed.
    return [
        all_padded_tensors[
            [
                slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size)
                for dim, dim_size in enumerate(shape)
            ]
        ]
        for rank, shape in enumerate(shapes)
    ]


def all_gather(
    tensor: Union[torch.Tensor, float, str],
    group: Optional[Union[Any, List[int]]] = None,
    tensor_different_shape: bool = False,
) -> Union[torch.Tensor, float, List[float], List[str], List[torch.Tensor]]:
    """Helper method to perform all gather operation.

    Args:
        tensor: tensor or number or str to collect across participating processes. If tensor, it should have
            the same number of dimensions across processes.
        group: list of integer or the process group for each backend. If None, the default process group will be used.
        tensor_different_shape: If True, it accounts for difference in input shape across processes. In this case, it
            induces more collective operations. If False, ``tensor`` should have the same shape across processes.
            Ignored when ``tensor`` is not a tensor. Default False.

    Returns:
        If input is a tensor, returns a torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)``
        if ``tensor_different_shape = False``, otherwise a list of tensors with length ``world_size`` (if ``group``
        is `None`) or ``len(group)``. If the current process does not belong to ``group``, a list with ``tensor``
        as its only item is returned.
        If input is a number, a torch.Tensor of shape ``(world_size, )`` is returned. Finally, a list of strings
        is returned if input is a string.

    .. versionchanged:: 0.4.11
        added ``group``

    .. versionchanged:: 0.5.1
        added ``tensor_different_shape``
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    # An explicit rank list is converted into a backend-specific process-group handle.
    if isinstance(group, list) and all(isinstance(item, int) for item in group):
        group = _model.new_group(group)

    if isinstance(tensor, torch.Tensor) and tensor_different_shape:
        # Serial (non-distributed) model, or a process outside `group`: nothing to
        # gather, return the local tensor as the sole list item.
        if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group):
            return [tensor]
        # Two collectives: first gather every process's shape, then gather the (padded)
        # tensors themselves via the shape-aware helper.
        all_shapes: torch.Tensor = _model.all_gather(torch.tensor(tensor.shape), group=group).view(
            -1, len(tensor.shape)
        )
        return _all_gather_tensors_with_shapes(tensor, all_shapes.tolist(), group=group)

    return _model.all_gather(tensor, group=group)


Expand Down
4 changes: 4 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ignite.metrics.gan.inception_score import InceptionScore
from ignite.metrics.loss import Loss
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_average_precision import MeanAveragePrecision
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
Expand All @@ -23,6 +24,7 @@
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
from ignite.metrics.vision.object_detection_map import ObjectDetectionMAP

__all__ = [
"Metric",
Expand Down Expand Up @@ -58,4 +60,6 @@
"Rouge",
"RougeN",
"RougeL",
"MeanAveragePrecision",
"ObjectDetectionMAP",
]
Loading