Add base elements to support distributed comms. Add supports_distributed plugin flag. #1370
First new file (1 line added) — a package `__init__.py` that re-exports the helper module:

```python
from .distributed_helper import *
```
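Taken together with the `__all__` list at the bottom of `distributed_helper.py`, this wildcard re-export makes the hash helpers importable directly from the package. A minimal usage sketch, assuming the package is `avalanche.distributed` (the actual directory name is not visible in this diff):

```python
# Assumed package path; the diff does not show the directory name.
from avalanche.distributed import hash_model, hash_minibatch
```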
Second new file, `distributed_helper.py` (108 lines added):

```python
import hashlib
import io

from typing import Tuple, TYPE_CHECKING

import torch

from torch.utils.data import DataLoader

if TYPE_CHECKING:
    from torch import Tensor
    from torch.nn import Module
    from avalanche.benchmarks import DatasetScenario
    from torch.utils.data import Dataset


def hash_benchmark(benchmark: 'DatasetScenario', *,
                   hash_engine=None, num_workers=0) -> str:
```
Review comment on lines +17 to +18 (the `hash_benchmark` signature):

**Collaborator:** Shouldn't this be a class method (`hash`)? Same for the other classes in this file, except the classes defined outside of avalanche.

**Collaborator (author):** Yes, I think I can move those elements to the appropriate classes.

**Collaborator (author):** It seems that the only avalanche-specific hash function in that file is […]

**Collaborator:** I think it's better to reuse the class […]

**Collaborator (author):** Alas, […]

**Collaborator:** Ok, if it's different we can keep it as is. Maybe it will be more clear to me once I see how you use it for distributed training.
The diff resumes with the body of `hash_benchmark`:

```python
    if hash_engine is None:
        hash_engine = hashlib.sha256()

    for stream_name in sorted(benchmark.streams.keys()):
        stream = benchmark.streams[stream_name]
        hash_engine.update(stream_name.encode())
        for experience in stream:
            exp_dataset = experience.dataset
            hash_dataset(exp_dataset,
                         hash_engine=hash_engine,
                         num_workers=num_workers)
    return hash_engine.hexdigest()


def hash_dataset(dataset: 'Dataset', *, hash_engine=None, num_workers=0) -> str:
    if hash_engine is None:
        hash_engine = hashlib.sha256()

    data_loader = DataLoader(
        dataset,
        collate_fn=lambda batch: tuple(zip(*batch)),
        num_workers=num_workers
    )
    for loaded_elem in data_loader:
        example = tuple(tuple_element[0] for tuple_element in loaded_elem)

        # https://stackoverflow.com/a/63880190
        buff = io.BytesIO()
        torch.save(example, buff)
        buff.seek(0)
        hash_engine.update(buff.read())
    return hash_engine.hexdigest()


def hash_minibatch(minibatch: 'Tuple[Tensor]', *, hash_engine=None) -> str:
    if hash_engine is None:
        hash_engine = hashlib.sha256()

    for tuple_elem in minibatch:
        buff = io.BytesIO()
        torch.save(tuple_elem, buff)
        buff.seek(0)
        hash_engine.update(buff.read())
    return hash_engine.hexdigest()


def hash_tensor(tensor: 'Tensor', *, hash_engine=None) -> str:
    if hash_engine is None:
        hash_engine = hashlib.sha256()

    buff = io.BytesIO()
    torch.save(tensor, buff)
    buff.seek(0)
    hash_engine.update(buff.read())
    return hash_engine.hexdigest()


def hash_model(
        model: 'Module',
        include_buffers=True,
        *,
        hash_engine=None) -> str:
    if hash_engine is None:
        hash_engine = hashlib.sha256()

    for name, param in model.named_parameters():
        hash_engine.update(name.encode())
        buff = io.BytesIO()
        torch.save(param.detach().cpu(), buff)
        buff.seek(0)
        hash_engine.update(buff.read())

    if include_buffers:
        for name, model_buffer in model.named_buffers():
            hash_engine.update(name.encode())
            buff = io.BytesIO()
            torch.save(model_buffer.detach().cpu(), buff)
            buff.seek(0)
            hash_engine.update(buff.read())

    return hash_engine.hexdigest()


__all__ = [
    'hash_benchmark',
    'hash_dataset',
    'hash_minibatch',
    'hash_tensor',
    'hash_model'
]
```
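These helpers fit the PR's stated goal of supporting distributed communication: each process can hash its local copy of a model, minibatch, or dataset and compare digests to confirm that replicated state is identical across ranks. Below is a minimal sketch of that pattern, assuming a `torch.distributed` process group is already initialized; `verify_model_sync` and the import path are illustrative, not part of this PR:

```python
# Sketch: cross-rank consistency check built on hash_model.
# Assumes torch.distributed is initialized (e.g. via init_process_group);
# `verify_model_sync` is a hypothetical helper, not part of this PR.
import torch.distributed as dist

from avalanche.distributed import hash_model  # assumed package path


def verify_model_sync(model) -> bool:
    """Return True if every rank computed the same model hash."""
    digest = hash_model(model)  # hex digest over parameters and buffers

    # Collect each rank's digest; all_gather_object pickles arbitrary
    # Python objects, so plain hex strings are fine here.
    digests = [None] * dist.get_world_size()
    dist.all_gather_object(digests, digest)

    return all(d == digests[0] for d in digests)
```

Because the digests are computed over `torch.save` byte streams, all ranks should run the same PyTorch version for the comparison to be meaningful.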