6 changes: 5 additions & 1 deletion .github/workflows/environment-update.yml
@@ -61,7 +61,11 @@ jobs:
run: |
python -m unittest discover tests &&
echo "Running checkpointing tests..." &&
bash ./tests/checkpointing/test_checkpointing.sh
bash ./tests/checkpointing/test_checkpointing.sh &&
echo "Running distributed training tests..." &&
cd tests &&
PYTHONPATH=.. python run_dist_tests.py &&
cd ..
- name: checkout avalanche-docker repo
if: always()
uses: actions/checkout@v3
4 changes: 4 additions & 0 deletions .github/workflows/unit-test.yml
@@ -58,5 +58,9 @@ jobs:
PYTHONPATH=. python examples/eval_plugin.py &&
echo "Running checkpointing tests..." &&
bash ./tests/checkpointing/test_checkpointing.sh &&
echo "Running distributed training tests..." &&
cd tests &&
PYTHONPATH=.. python run_dist_tests.py &&
cd .. &&
echo "While running unit tests, the following datasets were downloaded:" &&
ls ~/.avalanche/data
7 changes: 5 additions & 2 deletions avalanche/benchmarks/classic/cmnist.py
@@ -28,8 +28,11 @@
)
from avalanche.benchmarks.datasets.external_datasets.mnist import \
get_mnist_dataset
from ..utils import make_classification_dataset, DefaultTransformGroups
from ..utils.data import make_avalanche_dataset
from avalanche.benchmarks.utils import (
make_classification_dataset,
DefaultTransformGroups,
)
from avalanche.benchmarks.utils.data import make_avalanche_dataset

_default_mnist_train_transform = Compose(
[Normalize((0.1307,), (0.3081,))]
52 changes: 42 additions & 10 deletions avalanche/benchmarks/datasets/lvis_dataset/lvis_dataset.py
@@ -159,20 +159,20 @@ def __getitem__(self, index):
"""
img_id = self.img_ids[index]
img_dict: LVISImgEntry = self.lvis_api.load_imgs(ids=[img_id])[0]
annotation_dicts = self.targets[index]
annotation_dicts: LVISImgTargets = self.targets[index]

# Transform from LVIS dictionary to torchvision-style target
num_objs = len(annotation_dicts)
num_objs = annotation_dicts["bbox"].shape[0]

boxes = []
labels = []
for i in range(num_objs):
xmin = annotation_dicts[i]["bbox"][0]
ymin = annotation_dicts[i]["bbox"][1]
xmax = xmin + annotation_dicts[i]["bbox"][2]
ymax = ymin + annotation_dicts[i]["bbox"][3]
xmin = annotation_dicts["bbox"][i][0]
ymin = annotation_dicts["bbox"][i][1]
xmax = xmin + annotation_dicts["bbox"][i][2]
ymax = ymin + annotation_dicts["bbox"][i][3]
boxes.append([xmin, ymin, xmax, ymax])
labels.append(annotation_dicts[i]["category_id"])
labels.append(annotation_dicts["category_id"][i])

if len(boxes) > 0:
boxes = torch.as_tensor(boxes, dtype=torch.float32)
@@ -183,7 +183,7 @@ def __getitem__(self, index):
image_id = torch.tensor([img_id])
areas = []
for i in range(num_objs):
areas.append(annotation_dicts[i]["area"])
areas.append(annotation_dicts["area"][i])
areas = torch.as_tensor(areas, dtype=torch.float32)
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

@@ -233,7 +233,17 @@ class LVISAnnotationEntry(TypedDict):
category_id: int


class LVISDetectionTargets(Sequence[List[LVISAnnotationEntry]]):
class LVISImgTargets(TypedDict):
id: torch.Tensor
area: torch.Tensor
segmentation: List[List[List[float]]]
image_id: torch.Tensor
bbox: torch.Tensor
category_id: torch.Tensor
labels: torch.Tensor


class LVISDetectionTargets(Sequence[List[LVISImgTargets]]):
def __init__(
self,
lvis_api: LVIS,
@@ -254,7 +264,28 @@ def __getitem__(self, index):
annotation_dicts: List[LVISAnnotationEntry] = self.lvis_api.load_anns(
annotation_ids
)
return annotation_dicts

n_annotations = len(annotation_dicts)

category_tensor = torch.empty((n_annotations,), dtype=torch.long)
target_dict: LVISImgTargets = {
'bbox': torch.empty((n_annotations, 4), dtype=torch.float32),
'category_id': category_tensor,
'id': torch.empty((n_annotations,), dtype=torch.long),
'area': torch.empty((n_annotations,), dtype=torch.float32),
'image_id': torch.full((1,), img_id, dtype=torch.long),
'segmentation': [],
'labels': category_tensor # Alias of category_id
}

for ann_idx, annotation in enumerate(annotation_dicts):
target_dict['bbox'][ann_idx] = torch.as_tensor(annotation['bbox'])
target_dict['category_id'][ann_idx] = annotation['category_id']
target_dict['id'][ann_idx] = annotation['id']
target_dict['area'][ann_idx] = annotation['area']
target_dict['segmentation'].append(annotation['segmentation'])

return target_dict


def _test_to_tensor(a, b):
@@ -316,5 +347,6 @@ def _plot_detection_sample(img: Image.Image, target):
"LvisDataset",
"LVISImgEntry",
"LVISAnnotationEntry",
"LVISImgTargets",
"LVISDetectionTargets",
]
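Note (not part of the diff): the change above replaces the old list-of-annotation-dicts target with a single per-image dictionary of tensors. Below is a minimal sketch of how such a target is laid out and how `__getitem__` turns LVIS xywh boxes into torchvision xyxy boxes; the values are made up and `toy_target` is a hypothetical name.

```python
import torch

# Hypothetical per-image target shaped like LVISImgTargets: one tensor per field,
# instead of the previous list of per-annotation dictionaries.
category_ids = torch.tensor([3, 7], dtype=torch.long)
toy_target = {
    "bbox": torch.tensor([[10.0, 20.0, 30.0, 40.0],
                          [5.0, 5.0, 10.0, 10.0]]),  # LVIS format: x, y, w, h
    "category_id": category_ids,
    "id": torch.tensor([101, 102], dtype=torch.long),
    "area": torch.tensor([1200.0, 100.0]),
    "image_id": torch.tensor([7], dtype=torch.long),
    "segmentation": [[[10.0, 20.0, 40.0, 20.0, 40.0, 60.0]],
                     [[5.0, 5.0, 15.0, 5.0, 15.0, 15.0]]],
    "labels": category_ids,  # alias of category_id, as in the diff
}

# Same conversion performed in LvisDataset.__getitem__ above: xywh -> xyxy.
num_objs = toy_target["bbox"].shape[0]
boxes = []
for i in range(num_objs):
    x, y, w, h = toy_target["bbox"][i].tolist()
    boxes.append([x, y, x + w, y + h])
boxes = torch.as_tensor(boxes, dtype=torch.float32)
print(boxes)  # tensor([[10., 20., 40., 60.], [ 5.,  5., 15., 15.]])
```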
44 changes: 22 additions & 22 deletions avalanche/benchmarks/utils/data_loader.py
@@ -17,7 +17,7 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import torch
from torch.utils.data import RandomSampler, DistributedSampler
from torch.utils.data import RandomSampler, DistributedSampler, Dataset
from torch.utils.data.dataloader import DataLoader

from avalanche.benchmarks.utils.collate_functions import (
@@ -31,6 +31,7 @@
)
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_attribute import DataAttribute
from avalanche.distributed.distributed_helper import DistributedHelper

_default_collate_mbatches_fn = classification_collate_mbatches_fn

@@ -284,14 +285,14 @@ def __init__(
self.collate_mbatches = collate_mbatches

for data in self.datasets:
if _DistributedHelper.is_distributed and distributed_sampling:
if DistributedHelper.is_distributed and distributed_sampling:
seed = torch.randint(
0,
2 ** 32 - 1 - _DistributedHelper.world_size,
2 ** 32 - 1 - DistributedHelper.world_size,
(1,),
dtype=torch.int64,
)
seed += _DistributedHelper.rank
seed += DistributedHelper.rank
generator = torch.Generator()
generator.manual_seed(int(seed))
else:
@@ -584,11 +585,11 @@ def _get_batch_sizes(


def _make_data_loader(
dataset,
distributed_sampling,
data_loader_args,
batch_size,
force_no_workers=False,
dataset: Dataset,
distributed_sampling: bool,
data_loader_args: Dict[str, Any],
batch_size: int,
force_no_workers: bool = False,
):
data_loader_args = data_loader_args.copy()

@@ -601,14 +602,22 @@ def _make_data_loader(
if 'prefetch_factor' in data_loader_args:
data_loader_args['prefetch_factor'] = 2

if _DistributedHelper.is_distributed and distributed_sampling:
if DistributedHelper.is_distributed and distributed_sampling:
# Note: shuffle only goes in the sampler, while
# drop_last must be passed to both the sampler
# and the DataLoader
drop_last = data_loader_args.pop("drop_last", False)
sampler = DistributedSampler(
dataset,
shuffle=data_loader_args.pop("shuffle", False),
drop_last=data_loader_args.pop("drop_last", False),
shuffle=data_loader_args.pop("shuffle", True),
drop_last=drop_last,
)
data_loader = DataLoader(
dataset, sampler=sampler, batch_size=batch_size, **data_loader_args
dataset,
sampler=sampler,
batch_size=batch_size,
drop_last=drop_last,
**data_loader_args
)
else:
sampler = None
@@ -619,15 +628,6 @@
return data_loader, sampler


class __DistributedHelperPlaceholder:
is_distributed = False
world_size = 1
rank = 0


_DistributedHelper = __DistributedHelperPlaceholder()


__all__ = [
"detection_collate_fn",
"detection_collate_mbatches_fn",
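Aside (not from the PR): in the distributed branch, `_make_data_loader` now pops `drop_last` once and forwards it to both the `DistributedSampler` and the `DataLoader`, while `shuffle` goes only to the sampler (and defaults to True in that branch). A minimal standalone sketch of that pattern follows; `num_replicas` and `rank` are hard-coded only so it runs without an initialized process group, whereas the real code relies on `DistributedHelper`.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Toy map-style dataset standing in for an AvalancheDataset.
dataset = TensorDataset(torch.arange(10).float(), torch.arange(10))

drop_last = True
sampler = DistributedSampler(
    dataset,
    num_replicas=2,   # normally inferred from the initialized process group
    rank=0,           # normally the rank of the current process
    shuffle=True,     # shuffling is handled by the sampler, not the DataLoader
    drop_last=drop_last,
)

# drop_last goes to BOTH the sampler (shard-level trimming) and the
# DataLoader (dropping the final partial minibatch on this rank).
loader = DataLoader(dataset, sampler=sampler, batch_size=2, drop_last=drop_last)

for epoch in range(2):
    sampler.set_epoch(epoch)  # gives a different shuffle each epoch
    for x, y in loader:
        print(epoch, y.tolist())
```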
28 changes: 25 additions & 3 deletions avalanche/core.py
@@ -1,5 +1,5 @@
from abc import ABC
from typing import TypeVar, Generic
from typing import Optional, Type, TypeVar, Generic
from typing import TYPE_CHECKING

if TYPE_CHECKING:
@@ -27,8 +27,18 @@ class BasePlugin(Generic[Template], ABC):
and loggers.
"""

def __init__(self):
pass
def __init__(self, supports_distributed: bool = False):
"""
Initializes a strategy plugin.

:param supports_distributed: If True, this plugin instance supports
    distributed training. Defaults to False.
"""

self.supports_distributed = supports_distributed
"""
A flag describing whether this plugin supports distributed training
"""

def before_training(self, strategy: Template, *args, **kwargs):
"""Called before `train` by the `BaseTemplate`."""
@@ -68,6 +78,18 @@ def after_eval(self, strategy: Template, *args, **kwargs) -> CallbackResult:
"""Called after `eval` by the `BaseTemplate`."""
pass

def _check_distributed_support(
self,
distributed_training_param: Optional[bool],
main_class: Type) -> bool:
if distributed_training_param is None:
if self.__class__ == main_class:
return True
else:
return False

return distributed_training_param


class BaseSGDPlugin(BasePlugin[Template], ABC):
"""ABC for BaseSGDTemplate plugins.
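Aside (not from the PR): the sketch below illustrates the constructor pattern introduced in `BasePlugin`, with a hypothetical subclass opting in via `supports_distributed=True`. The class names are toy stand-ins, and how the flag is consumed by the training templates is not shown in this diff.

```python
# Toy mirror of the new BasePlugin constructor (names are hypothetical).
class ToyBasePlugin:
    def __init__(self, supports_distributed: bool = False):
        # Flag advertising that the plugin has been adapted for distributed training.
        self.supports_distributed = supports_distributed


class ToyDistributedAwarePlugin(ToyBasePlugin):
    def __init__(self):
        # A plugin verified to work under torch.distributed declares it
        # explicitly when calling the parent constructor.
        super().__init__(supports_distributed=True)


class ToyLegacyPlugin(ToyBasePlugin):
    pass  # keeps the conservative default: supports_distributed == False


print(ToyDistributedAwarePlugin().supports_distributed)  # True
print(ToyLegacyPlugin().supports_distributed)            # False
```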
1 change: 1 addition & 0 deletions avalanche/distributed/__init__.py
@@ -0,0 +1 @@
from .distributed_helper import *
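Aside (not from the PR): with this re-export, the helper used by `data_loader.py` becomes importable from the package root. A minimal sketch (assuming this branch is installed) of the three attributes that module relies on; outside a distributed run they fall back to the single-process values the removed placeholder used (`is_distributed=False`, `world_size=1`, `rank=0`).

```python
# Sketch only: attribute names match the ones used by _make_data_loader above.
from avalanche.distributed import DistributedHelper

if DistributedHelper.is_distributed:
    print(f"rank {DistributedHelper.rank} of {DistributedHelper.world_size}")
else:
    print("single-process run: world_size =", DistributedHelper.world_size)
```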
108 changes: 108 additions & 0 deletions avalanche/distributed/distributed_consistency_verification.py
@@ -0,0 +1,108 @@
import hashlib
import io

from typing import Tuple, TYPE_CHECKING

import torch

from torch.utils.data import DataLoader

if TYPE_CHECKING:
from torch import Tensor
from torch.nn import Module
from avalanche.benchmarks import DatasetScenario
from torch.utils.data import Dataset


def hash_benchmark(benchmark: 'DatasetScenario', *,
hash_engine=None, num_workers=0) -> str:
Comment on lines +17 to +18

Collaborator: Shouldn't this be a class method (`hash`)? Same for the other classes in this file, except the classes defined outside of avalanche.

Collaborator (Author): Yes, I think I can move those elements to the appropriate classes.

Collaborator (Author): It seems that the only avalanche-specific hash function in that file is hash_benchmark. Do you think it is still appropriate to move it to CLScenario?

Collaborator: I think it's better to reuse the class __hash__ method if possible, so that child classes can safely override its behavior if needed. Also, hash_dataset should work only on AvalancheDataset, since we don't really support any other dataset.

Collaborator (Author): Alas, __hash__ must return an int. It is designed to provide a coarse mechanism for populating hash maps. I think that we can just move those methods to the CLScenario and AvalancheDataset classes for the moment.

Collaborator: OK, if it's different we can keep it as is. Maybe it will be more clear to me once I see how you use it for distributed training.

if hash_engine is None:
hash_engine = hashlib.sha256()

for stream_name in sorted(benchmark.streams.keys()):
stream = benchmark.streams[stream_name]
hash_engine.update(stream_name.encode())
for experience in stream:
exp_dataset = experience.dataset
hash_dataset(exp_dataset,
hash_engine=hash_engine,
num_workers=num_workers)
return hash_engine.hexdigest()


def hash_dataset(dataset: 'Dataset', *, hash_engine=None, num_workers=0) -> str:
if hash_engine is None:
hash_engine = hashlib.sha256()

data_loader = DataLoader(
dataset,
collate_fn=lambda batch: tuple(zip(*batch)),
num_workers=num_workers
)
for loaded_elem in data_loader:
example = tuple(tuple_element[0] for tuple_element in loaded_elem)

# https://stackoverflow.com/a/63880190
buff = io.BytesIO()
torch.save(example, buff)
buff.seek(0)
hash_engine.update(buff.read())
return hash_engine.hexdigest()


def hash_minibatch(minibatch: 'Tuple[Tensor]', *, hash_engine=None) -> str:
if hash_engine is None:
hash_engine = hashlib.sha256()

for tuple_elem in minibatch:
buff = io.BytesIO()
torch.save(tuple_elem, buff)
buff.seek(0)
hash_engine.update(buff.read())
return hash_engine.hexdigest()


def hash_tensor(tensor: 'Tensor', *, hash_engine=None) -> str:
if hash_engine is None:
hash_engine = hashlib.sha256()

buff = io.BytesIO()
torch.save(tensor, buff)
buff.seek(0)
hash_engine.update(buff.read())
return hash_engine.hexdigest()


def hash_model(
model: 'Module',
include_buffers=True,
*,
hash_engine=None) -> str:
if hash_engine is None:
hash_engine = hashlib.sha256()

for name, param in model.named_parameters():
hash_engine.update(name.encode())
buff = io.BytesIO()
torch.save(param.detach().cpu(), buff)
buff.seek(0)
hash_engine.update(buff.read())

if include_buffers:
for name, model_buffer in model.named_buffers():
hash_engine.update(name.encode())
buff = io.BytesIO()
torch.save(model_buffer.detach().cpu(), buff)
buff.seek(0)
hash_engine.update(buff.read())

return hash_engine.hexdigest()


__all__ = [
'hash_benchmark',
'hash_dataset',
'hash_minibatch',
'hash_tensor',
'hash_model'
]
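Aside (not from the PR): a usage sketch of the new hashing helpers, assuming this branch is installed so the import path matches the new file above. Each helper returns a hex digest string; in a distributed run, ranks can exchange these digests (for example with `torch.distributed.all_gather_object`) to verify that they hold byte-identical models, minibatches, or datasets.

```python
import torch
from torch.nn import Linear

# Import path follows the new module added in this PR.
from avalanche.distributed.distributed_consistency_verification import (
    hash_model,
    hash_tensor,
)

model_a = Linear(4, 2)
model_b = Linear(4, 2)
model_b.load_state_dict(model_a.state_dict())  # make the weights identical

# Identical parameters (and buffers) produce identical digests.
print(hash_model(model_a) == hash_model(model_b))  # expected: True

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
print(hash_tensor(t))  # a sha256 hex digest string
```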