Commit 547d30e

matthieu-d4r authored and IsaevIlya committed
feat(dcp): add benchmark capabilities for DCP (#259)
Create a dedicated `dcp` Python package within `s3torchbenchmarking`, to run benchmarks against fsspec. Use Pandas for result metrics.
1 parent dcb580e commit 547d30e

9 files changed: +506 -19 lines changed

s3torchbenchmarking/conf/dcp.yaml

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+s3:
+  region: ???
+  uri: ???
+epochs: 4
+path: ./nvme/ # only used when `checkpoint.storage` contains `disk`, ignored for `s3`
+
+# https://hydra.cc/docs/tutorials/basic/running_your_app/multi-run/#sweeper
+hydra:
+  mode: MULTIRUN
+  sweeper:
+    params:
+      +model: vit-base,T0_3B
+      +backend: nccl,gloo # nccl == GPU, gloo == CPU
+      +world_size: 1,2,4,8 # == total number of workers to use
+      +thread_count: 1,2,4,8
+      +checkpoint.storage: disk,s3
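
The sweep in `dcp.yaml` expands to the Cartesian product of its parameter lists. The following is a minimal sketch of that expansion in plain Python, assuming only the values shown in the file above. `expand_sweep` is a hypothetical helper that does not call Hydra, and the ordering is simply that of `itertools.product`, which is not claimed to match Hydra's job numbering.

```python
from itertools import product
from typing import Dict, List

# Sweep values copied from the `hydra.sweeper.params` block of dcp.yaml.
SWEEP: Dict[str, list] = {
    "model": ["vit-base", "T0_3B"],
    "backend": ["nccl", "gloo"],
    "world_size": [1, 2, 4, 8],
    "thread_count": [1, 2, 4, 8],
    "checkpoint.storage": ["disk", "s3"],
}


def expand_sweep(sweep: Dict[str, list]) -> List[dict]:
    """Return one dict per parameter combination (hypothetical helper)."""
    keys = list(sweep)
    return [dict(zip(keys, values)) for values in product(*sweep.values())]


jobs = expand_sweep(SWEEP)
print(len(jobs))  # 2 * 2 * 4 * 4 * 2 = 128
print(jobs[0])
```

With the default file, this gives 128 combinations, so trimming the lists before a first run may be worthwhile.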

s3torchbenchmarking/pyproject.toml

Lines changed: 11 additions & 2 deletions
@@ -27,9 +27,18 @@ dependencies = [
     "click",
     "omegaconf",
     "accelerate",
+    "pandas",
 ]
-optional-dependencies = { test = ["pytest"] }
-scripts = { s3torch-benchmark = "s3torchbenchmarking.benchmark:run_experiment", s3torch-datagen = "s3torchbenchmarking.datagen:synthesize_dataset" }
+
+[project.optional-dependencies]
+test = [
+    "pytest"
+]
+
+[project.scripts]
+s3torch-benchmark = "s3torchbenchmarking.benchmark:run_experiment"
+s3torch-datagen = "s3torchbenchmarking.datagen:synthesize_dataset"
+s3torch-benchmark-dcp = "s3torchbenchmarking.dcp.benchmark:run_benchmark"

 [tool.setuptools.packages]
 # Pure Python packages/modules and configuration files
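
Each `[project.scripts]` entry maps a command name to a `module:function` reference. The sketch below is a simplified illustration of how such a reference can be resolved; it is not the installer's actual implementation. `resolve_entry_point` is a hypothetical helper, and the standard-library reference `json:dumps` stands in for `s3torchbenchmarking.dcp.benchmark:run_benchmark` so the example runs without the benchmark package installed.

```python
from importlib import import_module
from typing import Callable


def resolve_entry_point(reference: str) -> Callable:
    """Resolve a "module:function" reference to the callable it names."""
    module_name, _, attribute = reference.partition(":")
    return getattr(import_module(module_name), attribute)


# Stand-in reference; the real script would point at the benchmark module.
dumps = resolve_entry_point("json:dumps")
print(dumps({"ok": True}))
```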
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
+## PyTorch's Distributed Checkpoint (DCP) benchmarks
+
+The `dcp` Python package holds all the logic to execute benchmarks for [PyTorch's Distributed Checkpointing][DCP]
+feature against the `s3torchconnector` library.
+
+### Purpose
+
+These benchmarks are designed to:
+
+1. Test the "save" mechanism of PyTorch DCP (`torch.distributed.checkpoint.save`);
+2. Compare the performance of the s3torchconnector library against other libraries and local storage;
+3. Measure throughput (in MiB/s) and save times (in seconds).
+
+### Usage
+
+> [!IMPORTANT]
+> The benchmarks are designed to be run on an EC2 instance.
+
+Install the `s3torchbenchmarking` package with `pip` (see the [root README](../../../README.md) for instructions); once
+installed, the DCP benchmarks can be run with:
+
+```shell
+$ s3torch-benchmark-dcp -cd conf -cn dcp
+```
+
+The command must be executed from the package's root, where it can read from the `conf/` directory; it will create a
+`./multirun/` directory (at the location of execution), and store all benchmark results there.
+
+> [!WARNING]
+> When saving on local disk, consider clearing the `path` specified in your config between runs to prevent disk space
+> issues.
+
+#### Potential caveats
+
+If you encounter the following error during installation:
+
+```
+TypeError: canonicalize_version() got an unexpected keyword argument 'strip_trailing_zero'
+```
+
+Run this command to resolve it:
+
+```shell
+$ pip install "setuptools<71"
+```
+
+### Configuration
+
+The benchmark runs can be customized using the [`dcp.yaml`](../../../conf/dcp.yaml) file. This section outlines the key
+configuration options and their impacts.
+
+#### Configuration Requirements
+
+All keys in the `dcp.yaml` file must be defined for a run to execute successfully.
+
+#### Key Configuration Options
+
+`epochs`
+
+- Specifies the number of iterations for "saving" a model's checkpoint.
+- Note: This does not affect model training, as no actual training occurs in these benchmarks.
+
+`path`
+
+- Designates the directory for benchmark operations.
+- If the specified directory doesn't exist, it will be created automatically.
+- For optimal performance using an SSD filesystem, refer to the [`prepare_nvme.sh`](../../../utils/prepare_nvme.sh)
+  script.
+
+`hydra.sweeper.params`
+
+This section allows for multiple benchmark configurations:
+
+- The benchmark will run sequential jobs for each combination of the specified parameters.
+- Available options include:
+  - `+model`: Choose from pre-trained models listed in [`models.py`](models.py).
+  - `+backend`: Select `nccl`, `gloo`, or both.
+  - `+world_size`: Defines the number of workers.
+  - `+thread_count`: Defines the number of threads to use for saving the checkpoints.
+  - `+checkpoint.storage`: Choose `s3`, `disk`, or both.
+
+#### Example Configuration
+
+```yaml
+s3:
+  region: eu-west-1
+  uri: s3://my-bucket
+epochs: 3
+path: ./nvme/
+
+hydra:
+  mode: MULTIRUN
+  sweeper:
+    params:
+      +model: vit-base,T0_3B
+      +backend: nccl,gloo
+      +world_size: 2,4
+      +thread_count: 1
+      +checkpoint.storage: s3,disk
+```
+
+This configuration will run benchmarks for all combinations of the specified models, backends, world sizes, thread
+counts, and storage options, totaling 16 (2×2×2×1×2) different benchmark scenarios.
+
+### Important notes
+
+- The benchmarks may take some time to complete, depending on the hardware and network configuration.
+- For optimal results, it is recommended to run the benchmarks on a dedicated EC2 instance without other
+  resource-intensive processes.
+- Ensure the specified S3 bucket exists in the given region and the EC2 user/role has read+write permissions.
+
+### Results
+
+Benchmark results are organized as follows:
+
+```shell
+multirun/
+└── YYYY-MM-DD
+    └── HH-MM-SS
+        ├── 0
+        │   ├── benchmark.log
+        │   └── results_small_nccl_2_2_s3.json
+        ├── 1
+        │   ├── benchmark.log
+        │   └── results_small_nccl_2_2_disk.json
+        ├── 2
+        │   ├── benchmark.log
+        │   └── results_small_nccl_4_2_s3.json
+        ├── 3
+        │   ├── benchmark.log
+        │   └── results_small_nccl_4_2_disk.json
+        └── multirun.yaml
+```
+
+Each run creates a timestamped subdirectory. The `./multirun/` directory is managed by [Hydra](https://hydra.cc/).
+
+Result file names reflect the parameter combinations, e.g.,
+
+```
++model: vit-base
++backend: nccl
++world_size: 2
++thread_count: 1
++checkpoint.storage: s3
+```
+
+will produce the file `results_vit-base_nccl_2_1_s3.json` (respecting parameters declaration order).
+
+### References
+
+- https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
+- https://pytorch.org/docs/stable/elastic/run.html
+- https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
+
+[DCP]: https://pytorch.org/docs/stable/distributed.checkpoint.html
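
The README states that result file names follow the declaration order of the sweep parameters. A minimal sketch of that naming rule and of locating result files under `./multirun/` follows. Both `result_file_name` and `find_result_files` are hypothetical helpers written for illustration; the `results.py` module that actually writes these files is not shown in this excerpt.

```python
from pathlib import Path
from typing import List


def result_file_name(
    model: str, backend: str, world_size: int, thread_count: int, storage: str
) -> str:
    """Build a result file name in the parameter order given by the README."""
    return f"results_{model}_{backend}_{world_size}_{thread_count}_{storage}.json"


def find_result_files(multirun_dir: Path) -> List[Path]:
    """List result files, assuming multirun/<date>/<time>/<job>/results_*.json."""
    return sorted(multirun_dir.glob("*/*/*/results_*.json"))


name = result_file_name("vit-base", "nccl", 2, 1, "s3")
print(name)  # results_vit-base_nccl_2_1_s3.json
```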
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# // SPDX-License-Identifier: BSD
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# // SPDX-License-Identifier: BSD
+
+import logging
+import os
+import random
+import string
+from multiprocessing.queues import Queue
+from pathlib import Path
+from time import perf_counter
+from typing import List
+
+import hydra
+import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dcp
+from omegaconf import DictConfig
+from torch import multiprocessing as mp
+from torch.distributed.checkpoint import FileSystemWriter
+from torch.nn import Module
+from torch.nn.parallel import DistributedDataParallel
+
+from s3torchconnector.dcp import S3StorageWriter
+from .constants import Timestamps
+from .models import get_benchmark_model
+from .results import save_results
+from ..benchmark_utils import ResourceMonitor
+
+logger = logging.getLogger(__name__)
+
+
+@hydra.main(version_base=None, config_path=".", config_name="config")
+def run_benchmark(cfg: DictConfig):
+    """DCP benchmark entry point."""
+    benchmark_model = get_benchmark_model(cfg.model)
+
+    # For every run, use a randomized suffix (for either local disk or S3).
+    suffix = "".join(random.choices(string.ascii_letters, k=7))
+    storage_writer = get_writer(cfg, suffix)
+
+    manager = mp.Manager()
+    corrected_save_timestamps: Queue[Timestamps] = manager.Queue()
+    processing_timestamps: List[Timestamps] = []
+
+    with ResourceMonitor() as monitor:
+        for epoch in range(cfg.epochs):
+            logger.info("Executing epoch #%i / %i...", epoch + 1, cfg.epochs)
+            begin_mp = perf_counter()
+            mp.spawn(
+                run,
+                (cfg, benchmark_model.model, storage_writer, corrected_save_timestamps),
+                nprocs=cfg.world_size,
+                join=True,
+            )
+            end_mp = perf_counter()
+            processing_timestamps.append((begin_mp, end_mp))
+
+    # Dump the multiprocessing Queue's content into a list.
+    collector: List[Timestamps] = []
+    while not corrected_save_timestamps.empty():
+        collector.append(corrected_save_timestamps.get())
+
+    save_results(
+        cfg,
+        benchmark_model,
+        corrected_save_timestamps=collector,
+        processing_timestamps=processing_timestamps,
+    )
+
+
+def get_writer(cfg: DictConfig, suffix: str) -> FileSystemWriter:
+    """Instantiate a checkpoint writer based on the input config."""
+    if cfg.checkpoint.storage == "disk":
+        local_path = Path(cfg.path) / suffix
+        logger.info("Saving checkpoint to %s (local disk)...", local_path)
+        return dcp.FileSystemWriter(local_path, thread_count=cfg.thread_count)
+    elif cfg.checkpoint.storage == "s3":
+        uri = build_checkpoint_uri(cfg.s3.uri, suffix)
+        logger.info("Saving checkpoint to %s (S3)...", uri)
+        return S3StorageWriter(cfg.s3.region, uri, thread_count=cfg.thread_count)
+    raise ValueError(f"Storage writer {cfg.checkpoint.storage} not supported")
+
+
+def build_checkpoint_uri(s3_uri: str, suffix: str) -> str:
+    suffix = suffix.lstrip("/")
+    return s3_uri + suffix if s3_uri.endswith("/") else s3_uri + "/" + suffix
+
+
+def setup(backend: str, world_size: int, rank: int) -> None:
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    dist.init_process_group(backend, world_size=world_size, rank=rank)
+
+
+# FIXME: configure logging in subprocess accordingly
+def run(
+    rank: int,  # needs to be passed first (provided by `multiprocessing.spawn` automatically)
+    cfg: DictConfig,
+    model: Module,
+    storage_writer: FileSystemWriter,
+    save_timestamps: Queue,
+) -> None:
+    """Execute the actual code for checkpoint saving.
+
+    This function is meant to be executed in subprocesses."""
+    begin_process = perf_counter()
+
+    setup(cfg.backend, world_size=cfg.world_size, rank=rank)
+    if cfg.backend == "nccl":
+        device_id = rank % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+    else:
+        device_id = rank % torch.cpu.device_count()
+        torch.cpu.set_device(device_id)
+
+    model.to(device_id)
+    model = DistributedDataParallel(model, device_ids=[device_id])
+    state_dict = model.state_dict()
+
+    begin_save = perf_counter()
+    dcp.save(state_dict, storage_writer=storage_writer)
+    end_save = perf_counter()
+
+    # Record the save times excluding the influence of the process setup and model loading to device.
+    save_timestamps.put((begin_process, end_save - (begin_save - begin_process)))
+
+    dist.destroy_process_group()
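
The tuple pushed by `run` shifts the end timestamp back by the setup overhead, so the resulting interval is as long as the `dcp.save` call alone. The sketch below reproduces that arithmetic with made-up numbers. `corrected_timestamps` and `throughput_mib_per_s` are hypothetical helpers, and the throughput formula (model size divided by save duration) is an assumption, since `results.py` is not shown in this excerpt.

```python
from typing import Tuple

Timestamps = Tuple[float, float]


def corrected_timestamps(
    begin_process: float, begin_save: float, end_save: float
) -> Timestamps:
    """Mirror the expression in `run`: drop the setup time from the interval."""
    setup_overhead = begin_save - begin_process
    return begin_process, end_save - setup_overhead


def throughput_mib_per_s(size_mib: float, timestamps: Timestamps) -> float:
    """Assumed metric: checkpoint size over the corrected save duration."""
    begin, end = timestamps
    return size_mib / (end - begin)


# Illustrative values only: 2.5 s of setup followed by a 4.0 s save.
ts = corrected_timestamps(begin_process=10.0, begin_save=12.5, end_save=16.5)
print(ts)  # (10.0, 14.0)
print(throughput_mib_per_s(350.0, ts))  # 87.5
```

The corrected interval starts at `begin_process`, so only its length, not its absolute position, corresponds to the save call.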
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from typing import Tuple
+
+Timestamps = Tuple[float, float]
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# // SPDX-License-Identifier: BSD
+
+from functools import cached_property
+from typing import Callable
+
+from torch.nn import Module
+from transformers import AutoModelForSeq2SeqLM, ViTModel, CLIPModel
+
+
+class BenchmarkModel:
+    """Utility class around a :class:`torch.nn.Module`, with an additional metadata layer."""
+
+    def __init__(self, loader: Callable, name: str):
+        self._loader = loader
+        self._name = name
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @cached_property
+    def model(self) -> Module:
+        return self._loader(self._name)
+
+    @cached_property
+    def size(self) -> float:
+        """Compute a model's size (in MiB).
+
+        Sourced from https://discuss.pytorch.org/t/finding-model-size/130275/2.
+        """
+        param_size = 0
+        for param in self.model.parameters():
+            param_size += param.nelement() * param.element_size()
+        buffer_size = 0
+        for buffer in self.model.buffers():
+            buffer_size += buffer.nelement() * buffer.element_size()
+        return (param_size + buffer_size) / 1024**2
+
+
+# NOTE: keys below are later used to construct a filename, so make sure they do not contain characters that will not
+#  play well with filesystems (e.g., '/').
+SIZE_TO_MODEL = {
+    # ~350 MB model
+    "vit-base": BenchmarkModel(
+        ViTModel.from_pretrained, "google/vit-base-patch16-224-in21k"
+    ),
+    # ~1.7 GB model
+    "clip-vit": BenchmarkModel(
+        CLIPModel.from_pretrained, "openai/clip-vit-large-patch14"
+    ),
+    # ~12 GB model
+    "T0_3B": BenchmarkModel(AutoModelForSeq2SeqLM.from_pretrained, "bigscience/T0_3B"),
+    # ~45 GB model
+    "T0pp": BenchmarkModel(AutoModelForSeq2SeqLM.from_pretrained, "bigscience/T0pp"),
+}
+
+
+def get_benchmark_model(name: str) -> BenchmarkModel:
+    """Select a model for benchmarking."""
+    if name not in SIZE_TO_MODEL:
+        raise ValueError(f'Name "{name}" is unexpected')
+    return SIZE_TO_MODEL[name]
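
`BenchmarkModel.size` sums `nelement() * element_size()` over all parameters and buffers, then divides by 1024² to get MiB. The sketch below reproduces that arithmetic without `torch`, describing each tensor by its shape and bytes per element. `TensorSpec` and `size_in_mib` are hypothetical names, and the example layer is made up for illustration.

```python
from math import prod
from typing import Iterable, Tuple

# (shape, bytes per element); stands in for a parameter or buffer tensor.
TensorSpec = Tuple[Tuple[int, ...], int]


def size_in_mib(tensors: Iterable[TensorSpec]) -> float:
    """Sum element count times element size over all tensors, in MiB."""
    total_bytes = sum(prod(shape) * element_size for shape, element_size in tensors)
    return total_bytes / 1024**2


# A 1024x1024 float32 weight matrix plus a 1024-element float32 bias.
linear_layer = [((1024, 1024), 4), ((1024,), 4)]
print(size_in_mib(linear_layer))  # 4.00390625
```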
