Commit a7667c1
refactor(benchmarks): improve jobs and runs saves, add write to DynamoDB (#280)
Create a "constants.py" file. Improve Hydra's callback behaviour and save run results to DynamoDB (optionally). Improve shell scripts used to run the benchmarks. Add a script to prepare an EC2 instance for benchmarking.
1 parent 1109d79 commit a7667c1
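For orientation: the changes below build on Hydra's experimental callback API. A minimal, illustrative sketch of the hook lifecycle this commit relies on (the hook names are Hydra's, as used in the diff further down; the class itself is hypothetical, not the one in the commit):

from typing import Any

from hydra.experimental.callback import Callback
from omegaconf import DictConfig


class LifecycleSketch(Callback):
    # Hydra invokes these hooks around a multirun (mode: MULTIRUN).

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        pass  # runs once, before the first job starts

    def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
        pass  # runs per job; job-scoped values (e.g. hydra.runtime.output_dir) exist here

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        pass  # runs once, after all jobs finish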

File tree

15 files changed: +318 -136 lines changed


s3torchbenchmarking/conf/dcp.yaml

Lines changed: 4 additions & 0 deletions
@@ -11,7 +11,11 @@ path: ???
 epochs: 4
 
 hydra:
+  job:
+    name: dcp
   mode: MULTIRUN
+  sweep:
+    dir: multirun/dcp/${now:%Y-%m-%d_%H-%M-%S}
   sweeper:
     params:
       +model: vit-base, T0_3B
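The `${now:...}` interpolation is Hydra's built-in time resolver, so each multirun lands in its own timestamped directory; the new `job.name` is what the callback (further down) records as the run's `scenario`.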

s3torchbenchmarking/conf/lightning_checkpointing.yaml

Lines changed: 4 additions & 0 deletions
@@ -12,7 +12,11 @@ epochs: 5
 save_one_in: 1
 
 hydra:
+  job:
+    name: lightning_checkpointing
   mode: MULTIRUN
+  sweep:
+    dir: multirun/lightning/${now:%Y-%m-%d_%H-%M-%S}
   sweeper:
     params:
       +model: vit-base, whisper, clip-vit, T0_3B, T0pp

s3torchbenchmarking/pyproject.toml

Lines changed: 1 addition & 7 deletions
@@ -25,9 +25,9 @@ dependencies = [
     "boto3",
     "prefixed",
     "click",
-    "omegaconf",
     "accelerate",
     "pandas",
+    "requests",
 ]
 
 [project.optional-dependencies]
@@ -38,9 +38,3 @@ test = [
 [project.scripts]
 s3torch-benchmark = "s3torchbenchmarking.benchmark:run_experiment"
 s3torch-datagen = "s3torchbenchmarking.datagen:synthesize_dataset"
-s3torch-benchmark-lightning = "s3torchbenchmarking.lightning_checkpointing.benchmark:run_benchmark"
-s3torch-benchmark-dcp = "s3torchbenchmarking.dcp.benchmark:run_benchmark"
-
-[tool.setuptools.packages]
-# Pure Python packages/modules and configuration files
-find = { where = ["src"] }
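With the `s3torch-benchmark-lightning` and `s3torch-benchmark-dcp` entry points dropped, the two benchmark modules instead gain `if __name__ == "__main__":` guards (see the diffs below), so they remain launchable directly, e.g. `python -m s3torchbenchmarking.dcp.benchmark` (that invocation is an assumption; per the commit message, the updated shell scripts are the intended entry path).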
s3torchbenchmarking/src/s3torchbenchmarking/constants.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# // SPDX-License-Identifier: BSD
+
+from typing import TypedDict, Union, Any, List
+
+JOB_RESULTS_FILENAME = "job_results.json"
+RUN_FILENAME = "run.json"
+
+# URLs for EC2 metadata retrieval (IMDSv2)
+URL_IMDS_TOKEN = "http://169.254.169.254/latest/api/token"
+URL_IMDS_DOCUMENT = "http://169.254.169.254/latest/dynamic/instance-identity/document"
+
+
+class Versions(TypedDict):
+    python: str
+    pytorch: str
+    hydra: str
+    s3torchconnector: str
+
+
+class EC2Metadata(TypedDict):
+    architecture: str
+    image_id: str
+    instance_type: str
+    region: str
+
+
+class Run(TypedDict):
+    """Information about a Hydra run.
+
+    Also, a :class:`Run` object will be inserted as-is in DynamoDB."""
+
+    run_id: str  # PK (Partition Key)
+    timestamp_utc: float  # SK (Sort Key)
+    scenario: str
+    versions: Versions
+    ec2_metadata: Union[EC2Metadata, None]
+    run_elapsed_time_s: float
+    number_of_jobs: int
+    all_job_results: List[Any]
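A quick sketch of a populated `Run` record, just to make the schema above concrete (all values hypothetical):

import uuid
from datetime import datetime, timezone

from s3torchbenchmarking.constants import Run

run: Run = {
    "run_id": str(uuid.uuid4()),  # DynamoDB partition key
    "timestamp_utc": datetime.now(timezone.utc).timestamp(),  # DynamoDB sort key
    "scenario": "dcp",
    "versions": {
        "python": "3.11.9",  # hypothetical version strings
        "pytorch": "2.3.0",
        "hydra": "1.3.2",
        "s3torchconnector": "1.2.5",
    },
    "ec2_metadata": None,  # None when IMDS is unreachable (i.e., not on EC2)
    "run_elapsed_time_s": 123.4,
    "number_of_jobs": 2,
    "all_job_results": [],
}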

s3torchbenchmarking/src/s3torchbenchmarking/dcp/benchmark.py

Lines changed: 10 additions & 3 deletions
@@ -19,10 +19,13 @@
 from torch.nn import Module
 from torch.nn.parallel import DistributedDataParallel
 
+from s3torchbenchmarking.benchmark_utils import (
+    build_random_suffix,
+    build_checkpoint_uri,
+)
+from s3torchbenchmarking.job_results import save_job_results
+from s3torchbenchmarking.models import get_benchmark_model
 from s3torchconnector.dcp import S3StorageWriter
-from ..benchmark_utils import build_random_suffix, build_checkpoint_uri
-from ..job_results import save_job_results
-from ..models import get_benchmark_model
 
 Timestamps = Tuple[float, float]
 logger = logging.getLogger(__name__)
@@ -129,3 +132,7 @@ def run(
     save_timestamps.put((begin_process, end_save - (begin_save - begin_process)))
 
     dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    run_benchmark()
Lines changed: 104 additions & 71 deletions
@@ -1,106 +1,139 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# // SPDX-License-Identifier: BSD
+
 import json
 import logging
-import re
-import subprocess
 import sys
-from functools import lru_cache
+import uuid
+from datetime import datetime, timezone
+from decimal import Decimal
 from pathlib import Path
 from time import perf_counter
-from typing import Any, List, TypedDict, Union, Optional
+from typing import Any, Union, Optional
 
+import boto3
+import requests
 import torch
+from botocore.exceptions import ClientError
 from hydra.experimental.callback import Callback
 from omegaconf import DictConfig
 
-_COLLATED_RESULTS_FILENAME = "collated_results.json"
+import s3torchconnector
+from s3torchbenchmarking.constants import (
+    JOB_RESULTS_FILENAME,
+    RUN_FILENAME,
+    Run,
+    EC2Metadata,
+    URL_IMDS_TOKEN,
+    URL_IMDS_DOCUMENT,
+)
 
 logger = logging.getLogger(__name__)
 
 
-class EC2Metadata(TypedDict):
-    instance_type: str
-    placement: str
-
-
-class Metadata(TypedDict):
-    python_version: str
-    pytorch_version: str
-    hydra_version: str
-    ec2_metadata: Union[EC2Metadata, None]
-    run_elapsed_time_s: float
-    number_of_jobs: int
-
-
-class CollatedResults(TypedDict):
-    metadata: Metadata
-    results: List[Any]
-
+class ResultCollatingCallback(Callback):
+    """Hydra callback (https://hydra.cc/docs/experimental/callbacks/).
 
+    Defines some routines to execute when a benchmark run is finished: namely, to merge all job results
+    ("job_results.json" files) in one place ("run.json" file), augmented with some metadata.
+    """
 
-class ResultCollatingCallback(Callback):
     def __init__(self) -> None:
-        self._multirun_dir: Optional[Path] = None
+        self._multirun_path: Optional[Path] = None
         self._begin = 0
-        self._end = 0
 
     def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        self._begin = perf_counter()
 
     def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
-        # Runtime variables like the output directory are not available in `on_multirun_end` is called, but are
-        # available in `on_job_start`, so we collect the path here and refer to it later.
-        if not self._multirun_dir:
-            # should be something like "./multirun/2024-11-08/15-47-08/"
-            self._multirun_dir = Path(config.hydra.runtime.output_dir).parent
+        if not self._multirun_path:
+            # Hydra variables like `hydra.runtime.output_dir` are not available inside :func:`on_multirun_end`, so we
+            # get the information here.
+            self._multirun_path = Path(config.hydra.runtime.output_dir).parent
 
     def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
-        self._end = perf_counter()
-        run_elapsed_time = self._end - self._begin
+        run_elapsed_time = perf_counter() - self._begin
+        run = self._build_run(config, run_elapsed_time)
+
+        self._save_to_disk(run)
+        if "dynamodb" in config:
+            self._write_to_dynamodb(config.dynamodb.region, config.dynamodb.table, run)
+        else:
+            logger.info("DynamoDB config not provided: skipping write to table...")
+
+    def _build_run(self, config: DictConfig, run_elapsed_time: float) -> Run:
+        all_job_results = []
+        for entry in self._multirun_path.glob(f"**/{JOB_RESULTS_FILENAME}"):
+            if entry.is_file():
+                all_job_results.append(json.loads(entry.read_text()))
+
+        logger.info("Collected %i job results", len(all_job_results))
+        return {
+            "run_id": str(uuid.uuid4()),
+            "timestamp_utc": datetime.now(timezone.utc).timestamp(),
+            "scenario": config.hydra.job.name,
+            "versions": {
+                "python": sys.version,
+                "pytorch": torch.__version__,
+                "hydra": config.hydra.runtime.version,
+                "s3torchconnector": s3torchconnector.__version__,
+            },
+            "ec2_metadata": _get_ec2_metadata(),
+            "run_elapsed_time_s": run_elapsed_time,
+            "number_of_jobs": len(all_job_results),
+            "all_job_results": all_job_results,
+        }
 
-        collated_results = self._collate_results(config, run_elapsed_time)
-        collated_results_path = self._multirun_dir / _COLLATED_RESULTS_FILENAME
+    def _save_to_disk(self, run: Run) -> None:
+        run_filepath = self._multirun_path / RUN_FILENAME
 
-        logger.info("Saving collated results to: %s", collated_results_path)
-        with open(collated_results_path, "w") as f:
-            json.dump(collated_results, f, ensure_ascii=False, indent=4)
-        logger.info("Collated results saved successfully")
+        logger.info("Saving run to: %s", run_filepath)
+        with open(run_filepath, "w") as f:
+            json.dump(run, f, ensure_ascii=False, indent=2)
+        logger.info("Run saved successfully")
 
-    def _collate_results(
-        self, config: DictConfig, run_elapsed_time: float
-    ) -> CollatedResults:
-        collated_results = []
-        for file in self._multirun_dir.glob("*/**/result*.json"):
-            collated_results.append(json.loads(file.read_text()))
+    @staticmethod
+    def _write_to_dynamodb(region: str, table_name: str, run: Run) -> None:
+        dynamodb = boto3.resource("dynamodb", region_name=region)
+        table = dynamodb.Table(table_name)
 
-        logger.info("Collated %i result files", len(collated_results))
-        return {
-            "metadata": {
-                "python_version": sys.version,
-                "pytorch_version": torch.__version__,
-                "hydra_version": config.hydra.runtime.version,
-                "ec2_metadata": get_ec2_metadata(),
-                "run_elapsed_time_s": run_elapsed_time,
-                "number_of_jobs": len(collated_results),
-            },
-            "results": collated_results,
-        }
+        # `parse_float=Decimal` is required for DynamoDB (the latter does not work with floats), so we perform that
+        # (strange) conversion through dumping then loading again the :class:`Run` object.
+        run_json = json.loads(json.dumps(run), parse_float=Decimal)
 
+        try:
+            logger.info("Putting item into table: %s", table_name)
+            table.put_item(Item=run_json)
+            logger.info("Put item into table successfully")
+        except ClientError:
+            logger.error("Couldn't put item into table %s", table, exc_info=True)
 
-@lru_cache
-def get_ec2_metadata() -> Union[EC2Metadata, None]:
-    """Get some EC2 metadata by running the `/opt/aws/bin/ec2-metadata` command.
 
-    The command's output is a single string of text, in a JSON-like format (_but not quite JSON_): hence, its content
-    is parsed using regex.
+def _get_ec2_metadata() -> Union[EC2Metadata, None]:
+    """Get some EC2 metadata.
 
-    The function's call is cached, so we don't execute the command multiple times per runs.
+    See also https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html#instancedata-inside-access.
     """
-    result = subprocess.run(
-        "/opt/aws/bin/ec2-metadata", capture_output=True, text=True, timeout=5
+    token = requests.put(
+        URL_IMDS_TOKEN,
+        headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
+        timeout=5.0,
+    )
+    if token.status_code != 200:
+        logger.warning("Failed to get EC2 metadata (acquiring token): %s", token)
+        return None
+
+    document = requests.get(
+        URL_IMDS_DOCUMENT, headers={"X-aws-ec2-metadata-token": token.text}, timeout=5.0
    )
-    if result.returncode == 0:
-        metadata = result.stdout
-        instance_type = re.search("instance-type: (.*)", metadata).group(1)
-        placement = re.search("placement: (.*)", metadata).group(1)
-        if instance_type and placement:
-            return {"instance_type": instance_type, "placement": placement}
-    return None
+    if document.status_code != 200:
+        logger.warning("Failed to get EC2 metadata (fetching document): %s", document)
+        return None
+
+    payload = document.json()
+    return {
+        "architecture": payload["architecture"],
+        "image_id": payload["imageId"],
+        "instance_type": payload["instanceType"],
+        "region": payload["region"],
+    }
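The `parse_float=Decimal` round-trip in `_write_to_dynamodb` is worth isolating, since boto3's DynamoDB resource rejects Python floats; a self-contained sketch of the same conversion (toy data standing in for a Run):

import json
from decimal import Decimal

run = {"run_elapsed_time_s": 12.34, "number_of_jobs": 2}

# Dump to JSON text, then parse back with every float decoded as Decimal,
# the shape that boto3's Table.put_item accepts.
item = json.loads(json.dumps(run), parse_float=Decimal)

assert isinstance(item["run_elapsed_time_s"], Decimal)
assert item["number_of_jobs"] == 2  # ints pass through unchanged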

s3torchbenchmarking/src/s3torchbenchmarking/job_results.py

Lines changed: 6 additions & 19 deletions
@@ -9,18 +9,14 @@
 from hydra.core.hydra_config import HydraConfig
 from omegaconf import DictConfig, OmegaConf
 
-from .models import BenchmarkModel
+from s3torchbenchmarking.constants import JOB_RESULTS_FILENAME
+from s3torchbenchmarking.models import BenchmarkModel
 
 logger = logging.getLogger(__name__)
 
 
-def save_job_results(
-    cfg: DictConfig,
-    model: BenchmarkModel,
-    metrics: Any,
-):
-    """Save a Hydra job results to a local JSON file."""
-
+def save_job_results(cfg: DictConfig, model: BenchmarkModel, metrics: Any) -> None:
+    """Save a single Hydra job results to a JSON file."""
     results = {
         "model": {
             "name": model.name,
@@ -30,19 +26,10 @@ def save_job_results(
         "metrics": metrics,
     }
 
-    tasks = HydraConfig.get().overrides.task
-
-    # extract only sweeper values (i.e., ones starting with '+')
-    tasks = [task for task in tasks if task.startswith("+")]
-    # turn ["foo=4", "bar=small", "baz=1"] into "4_small_1"
-    suffix = "_".join([task.split("=")[-1] for task in tasks]) if tasks else ""
-
-    # Save the results in the corresponding Hydra job directory (e.g., multirun/2024-11-08/15-47-08/0/<filename>.json).
-    results_filename = f"results{'_' + suffix if suffix else ''}.json"
     results_dir = HydraConfig.get().runtime.output_dir
-    results_path = Path(results_dir, results_filename)
+    results_path = Path(results_dir, JOB_RESULTS_FILENAME)
 
     logger.info("Saving job results to: %s", results_path)
     with open(results_path, "w") as f:
-        json.dump(results, f, ensure_ascii=False, indent=4)
+        json.dump(results, f, ensure_ascii=False, indent=2)
     logger.info("Job results saved successfully")

s3torchbenchmarking/src/s3torchbenchmarking/lightning_checkpointing/benchmark.py

Lines changed: 11 additions & 5 deletions
@@ -13,15 +13,17 @@
 from torch.utils.data import DataLoader
 from torchdata.datapipes.iter import IterableWrapper  # type: ignore
 
-from s3torchconnector.lightning import S3LightningCheckpoint
-from .checkpoint_profiler import CheckpointProfiler
-from ..benchmark_utils import (
+from s3torchbenchmarking.benchmark_utils import (
     ResourceMonitor,
     build_checkpoint_uri,
     build_random_suffix,
 )
-from ..job_results import save_job_results
-from ..models import get_benchmark_model, LightningAdapter
+from s3torchbenchmarking.job_results import save_job_results
+from s3torchbenchmarking.lightning_checkpointing.checkpoint_profiler import (
+    CheckpointProfiler,
+)
+from s3torchbenchmarking.models import get_benchmark_model, LightningAdapter
+from s3torchconnector.lightning import S3LightningCheckpoint
 
 logger = logging.getLogger(__name__)
 
@@ -70,3 +72,7 @@ def run_benchmark(config: DictConfig):
     }
 
     save_job_results(config, benchmark_model, metrics)
+
+
+if __name__ == "__main__":
+    run_benchmark()
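Since the callback only writes to DynamoDB when a `dynamodb` section is present in the config, enabling it is purely a configuration concern. A hypothetical sketch of that opt-in check (the key names match the callback's `config.dynamodb.region` / `config.dynamodb.table` reads; how the section is supplied, conf file or command-line override, is an assumption):

from omegaconf import OmegaConf

config = OmegaConf.create(
    {"dynamodb": {"region": "us-east-1", "table": "benchmark-runs"}}  # hypothetical values
)

if "dynamodb" in config:
    print("would write to", config.dynamodb.table, "in", config.dynamodb.region)
else:
    print("DynamoDB config not provided: skipping write to table...")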
