|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# // SPDX-License-Identifier: BSD |
| 3 | + |
1 | 4 | import json |
2 | 5 | import logging |
3 | | -import re |
4 | | -import subprocess |
5 | 6 | import sys |
6 | | -from functools import lru_cache |
| 7 | +import uuid |
| 8 | +from datetime import datetime, timezone |
| 9 | +from decimal import Decimal |
7 | 10 | from pathlib import Path |
8 | 11 | from time import perf_counter |
9 | | -from typing import Any, List, TypedDict, Union, Optional |
| 12 | +from typing import Any, Union, Optional |
10 | 13 |
|
| 14 | +import boto3 |
| 15 | +import requests |
11 | 16 | import torch |
| 17 | +from botocore.exceptions import ClientError |
12 | 18 | from hydra.experimental.callback import Callback |
13 | 19 | from omegaconf import DictConfig |
14 | 20 |
|
15 | | -_COLLATED_RESULTS_FILENAME = "collated_results.json" |
| 21 | +import s3torchconnector |
| 22 | +from s3torchbenchmarking.constants import ( |
| 23 | + JOB_RESULTS_FILENAME, |
| 24 | + RUN_FILENAME, |
| 25 | + Run, |
| 26 | + EC2Metadata, |
| 27 | + URL_IMDS_TOKEN, |
| 28 | + URL_IMDS_DOCUMENT, |
| 29 | +) |
16 | 30 |
|
17 | 31 | logger = logging.getLogger(__name__) |
18 | 32 |
|
19 | 33 |
|
class ResultCollatingCallback(Callback):
    """Hydra callback (https://hydra.cc/docs/experimental/callbacks/).

    Defines some routines to execute when a benchmark run is finished: namely, to merge all job results
    ("job_results.json" files) in one place ("run.json" file), augmented with some metadata.
    """

    def __init__(self) -> None:
        # Multirun output directory (e.g., "./multirun/2024-11-08/15-47-08/"); captured lazily in
        # :func:`on_job_start` because Hydra runtime config is not available in :func:`on_multirun_end`.
        self._multirun_path: Optional[Path] = None
        # `perf_counter()` reading taken at multirun start; used to compute the run's elapsed time.
        self._begin: float = 0

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Record the start time of the whole multirun."""
        self._begin = perf_counter()

    def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
        if not self._multirun_path:
            # Hydra variables like `hydra.runtime.output_dir` are not available inside
            # :func:`on_multirun_end`, so we get the information here (once, on the first job).
            self._multirun_path = Path(config.hydra.runtime.output_dir).parent

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Collate all job results, save them to disk, and optionally write them to DynamoDB."""
        run_elapsed_time = perf_counter() - self._begin
        run = self._build_run(config, run_elapsed_time)

        self._save_to_disk(run)
        if "dynamodb" in config:
            self._write_to_dynamodb(config.dynamodb.region, config.dynamodb.table, run)
        else:
            logger.info("DynamoDB config not provided: skipping write to table...")

    def _build_run(self, config: DictConfig, run_elapsed_time: float) -> Run:
        """Assemble a :class:`Run` payload from all job results found under the multirun directory.

        Args:
            config: Hydra config for the multirun (provides job name and Hydra version).
            run_elapsed_time: total wall-clock duration of the multirun, in seconds.

        Returns:
            A :class:`Run` dict combining versions, EC2 metadata, timing, and all job results.
        """
        all_job_results = []
        for entry in self._multirun_path.glob(f"**/{JOB_RESULTS_FILENAME}"):
            if entry.is_file():
                all_job_results.append(json.loads(entry.read_text()))

        logger.info("Collected %i job results", len(all_job_results))
        return {
            "run_id": str(uuid.uuid4()),
            "timestamp_utc": datetime.now(timezone.utc).timestamp(),
            "scenario": config.hydra.job.name,
            "versions": {
                "python": sys.version,
                "pytorch": torch.__version__,
                "hydra": config.hydra.runtime.version,
                "s3torchconnector": s3torchconnector.__version__,
            },
            "ec2_metadata": _get_ec2_metadata(),
            "run_elapsed_time_s": run_elapsed_time,
            "number_of_jobs": len(all_job_results),
            "all_job_results": all_job_results,
        }

    def _save_to_disk(self, run: Run) -> None:
        """Write the collated run to a JSON file inside the multirun directory."""
        run_filepath = self._multirun_path / RUN_FILENAME

        logger.info("Saving run to: %s", run_filepath)
        with open(run_filepath, "w") as f:
            json.dump(run, f, ensure_ascii=False, indent=2)
        logger.info("Run saved successfully")

    @staticmethod
    def _write_to_dynamodb(region: str, table_name: str, run: Run) -> None:
        """Put the collated run into a DynamoDB table (best-effort: errors are logged, not raised)."""
        dynamodb = boto3.resource("dynamodb", region_name=region)
        table = dynamodb.Table(table_name)

        # `parse_float=Decimal` is required for DynamoDB (the latter does not work with floats), so we perform that
        # (strange) conversion through dumping then loading again the :class:`Run` object.
        run_json = json.loads(json.dumps(run), parse_float=Decimal)

        try:
            logger.info("Putting item into table: %s", table_name)
            table.put_item(Item=run_json)
            logger.info("Put item into table successfully")
        except ClientError:
            # Log the table *name* (the `table` resource object's repr is noisy and unhelpful here).
            logger.error("Couldn't put item into table %s", table_name, exc_info=True)
87 | 110 |
|
def _get_ec2_metadata() -> Union[EC2Metadata, None]:
    """Get some EC2 metadata through IMDSv2 (token-authenticated instance metadata service).

    See also https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html#instancedata-inside-access.

    Returns:
        An :class:`EC2Metadata` dict, or ``None`` when the metadata cannot be obtained — e.g.,
        a non-200 response, or the IMDS endpoint being unreachable (not running on EC2).
    """
    try:
        token = requests.put(
            URL_IMDS_TOKEN,
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=5.0,
        )
        if token.status_code != 200:
            logger.warning("Failed to get EC2 metadata (acquiring token): %s", token)
            return None

        document = requests.get(
            URL_IMDS_DOCUMENT,
            headers={"X-aws-ec2-metadata-token": token.text},
            timeout=5.0,
        )
        if document.status_code != 200:
            logger.warning("Failed to get EC2 metadata (fetching document): %s", document)
            return None
    except requests.RequestException:
        # IMDS unreachable or timed out (likely not on an EC2 instance): metadata is optional,
        # so degrade gracefully instead of crashing the whole results collation.
        logger.warning("Failed to get EC2 metadata (request error)", exc_info=True)
        return None

    payload = document.json()
    return {
        "architecture": payload["architecture"],
        "image_id": payload["imageId"],
        "instance_type": payload["instanceType"],
        "region": payload["region"],
    }
0 commit comments