Skip to content

Commit 129966e

Browse files
authored
[nightly][ci] Logging to scuba when running in oss (#427)
1 parent dd0efa6 commit 129966e

File tree

7 files changed

+249
-172
lines changed

7 files changed

+249
-172
lines changed

.ci/upload/scribe.py

Lines changed: 18 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -5,131 +5,35 @@
55
import argparse
66
import json
77
import os
8-
import time
8+
import sys
9+
from os.path import abspath, exists
910

10-
from collections import defaultdict
1111

12-
import requests
12+
def setup_tritonbench_cwd():
13+
original_dir = abspath(os.getcwd())
1314

14-
CATEGORY_NAME = "perfpipe_pytorch_user_benchmarks"
15-
BENCHMARK_SCHEMA = {
16-
"int": ["time"],
17-
"normal": [
18-
"benchmark_date",
19-
"unix_user",
20-
"submission_group_id",
21-
"cuda_version",
22-
"device",
23-
"conda_env",
24-
"pytorch_commit",
25-
"triton_commit",
26-
"tritonbench_commit",
27-
"triton_branch",
28-
"pytorch_branch",
29-
"tritonbench_branch",
30-
"triton_commit_time",
31-
"pytorch_commit_time",
32-
"tritonbench_commit_time",
33-
"github_action",
34-
"github_actor",
35-
"github_base_ref",
36-
"github_ref",
37-
"github_ref_protected",
38-
"github_repository",
39-
"github_run_attempt",
40-
"github_run_id",
41-
"github_run_number",
42-
"github_workflow",
43-
"github_workflow_ref",
44-
"github_workflow_sha",
45-
"job_name",
46-
"runner_arch",
47-
"runner_name",
48-
"runner_type",
49-
"runner_os",
50-
"metric_id",
51-
],
52-
"float": ["metric_value"],
53-
}
15+
for tritonbench_dir in (
16+
".",
17+
"../../tritonbench",
18+
):
19+
if exists(tritonbench_dir):
20+
break
5421

22+
if exists(tritonbench_dir):
23+
tritonbench_dir = abspath(tritonbench_dir)
24+
os.chdir(tritonbench_dir)
25+
sys.path.append(tritonbench_dir)
26+
return original_dir
5527

56-
class ScribeUploader:
57-
def __init__(self, category, schema):
58-
self.category = category
59-
self.schema = schema
60-
61-
def _format_message(self, field_dict):
62-
assert "time" in field_dict, "Missing required Scribe field 'time'"
63-
message = defaultdict(dict)
64-
for field, value in field_dict.items():
65-
field = field.lower()
66-
if value is None:
67-
continue
68-
if field in self.schema["normal"]:
69-
message["normal"][field] = str(value)
70-
elif field in self.schema["int"]:
71-
message["int"][field] = int(value)
72-
elif field in self.schema["float"]:
73-
try:
74-
message["float"][field] = float(value)
75-
except ValueError:
76-
# If value error (e.g., "CUDA OOM"), override the field value to 0.0
77-
message["float"][field] = 0.0
78-
else:
79-
raise ValueError(
80-
"Field {} is not currently used, "
81-
"be intentional about adding new fields to schema".format(field)
82-
)
83-
return message
84-
85-
def _upload(self, messages: list):
86-
access_token = os.environ.get("TRITONBENCH_SCRIBE_GRAPHQL_ACCESS_TOKEN")
87-
if not access_token:
88-
raise ValueError("Can't find access token from environment variable")
89-
url = "https://graph.facebook.com/scribe_logs"
90-
r = requests.post(
91-
url,
92-
data={
93-
"access_token": access_token,
94-
"logs": json.dumps(
95-
[
96-
{
97-
"category": self.category,
98-
"message": json.dumps(message),
99-
"line_escape": False,
100-
}
101-
for message in messages
102-
]
103-
),
104-
},
105-
)
106-
print(r.text)
107-
r.raise_for_status()
108-
109-
def post_benchmark_results(self, bm_data):
110-
messages = []
111-
base_message = {
112-
"time": int(time.time()),
113-
}
114-
base_message.update(bm_data["env"])
115-
base_message.update(bm_data["github"])
116-
base_message["submission_group_id"] = f"tritonbench.{bm_data['name']}"
117-
base_message["unix_user"] = "tritonbench_ci"
118-
for metric in bm_data["metrics"]:
119-
msg = base_message.copy()
120-
msg["metric_id"] = metric
121-
msg["metric_value"] = bm_data["metrics"][metric]
122-
formatted_msg = self._format_message(msg)
123-
messages.append(formatted_msg)
124-
self._upload(messages)
12528

29+
setup_tritonbench_cwd()
30+
from tritonbench.utils.scuba_utils import log_benchmark
12631

12732
if __name__ == "__main__":
12833
parser = argparse.ArgumentParser()
12934
parser.add_argument(
13035
"--json", required=True, type=argparse.FileType("r"), help="Userbenchmark json"
13136
)
13237
args = parser.parse_args()
133-
uploader = ScribeUploader(category=CATEGORY_NAME, schema=BENCHMARK_SCHEMA)
13438
benchmark_data = json.load(args.json)
135-
uploader.post_benchmark_results(benchmark_data)
39+
log_benchmark(benchmark_data)

benchmarks/nightly/run.py

Lines changed: 29 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""
22
Tritonbench nightly run, dashboard: https://hud.pytorch.org/tritonbench/commit_view
3+
Run all operators in nightly/autogen.yaml.
34
Requires the operator to support the speedup metric.
45
"""
56

@@ -37,57 +38,6 @@ def setup_tritonbench_cwd():
3738
return original_dir
3839

3940

40-
def reduce(run_timestamp, output_dir, output_files, args):
41-
"""aggregate all op benchmark csvs into json file"""
42-
from tritonbench.utils.gpu_utils import get_nvidia_gpu_states, has_nvidia_smi
43-
from tritonbench.utils.path_utils import REPO_PATH
44-
from tritonbench.utils.run_utils import get_github_env, get_run_env
45-
46-
repo_locs = {
47-
"tritonbench": REPO_PATH,
48-
}
49-
if args.ci and "TRITONBENCH_TRITON_REPO_PATH" in os.environ:
50-
repo_locs["triton"] = os.environ.get("TRITONBENCH_TRITON_REPO_PATH", None)
51-
repo_locs["pytorch"] = os.environ.get("TRITONBENCH_PYTORCH_REPO_PATH", None)
52-
aggregated_obj = {
53-
"name": "nightly",
54-
"env": get_run_env(run_timestamp, repo_locs),
55-
"metrics": {},
56-
}
57-
if has_nvidia_smi():
58-
aggregated_obj.update(
59-
{
60-
"nvidia_gpu_states": get_nvidia_gpu_states(),
61-
}
62-
)
63-
64-
# Collecting GitHub environment variables when running in CI environment
65-
if args.ci:
66-
aggregated_obj["github"] = get_github_env()
67-
68-
for result_json_file in output_files:
69-
logger.info(f"Loading output file: {result_json_file}.")
70-
result_json_filename = Path(result_json_file).stem
71-
if (
72-
not os.path.exists(result_json_file)
73-
or os.path.getsize(result_json_file) == 0
74-
):
75-
aggregated_obj["metrics"][f"tritonbench_{result_json_filename}-pass"] = 0
76-
continue
77-
# TODO: check if all inputs pass
78-
aggregated_obj["metrics"][f"tritonbench_{result_json_filename}-pass"] = 1
79-
with open(
80-
result_json_file,
81-
"r",
82-
) as fp:
83-
result_obj = json.load(fp)
84-
aggregated_obj["metrics"].update(result_obj)
85-
result_json_path = os.path.join(output_dir, "result.json")
86-
with open(result_json_path, "w") as fp:
87-
json.dump(aggregated_obj, fp, indent=4)
88-
return result_json_path
89-
90-
9141
def get_operator_benchmarks() -> Dict[str, Any]:
9242
def _load_benchmarks(config_path: str) -> Dict[str, Any]:
9343
out = {}
@@ -111,12 +61,17 @@ def _load_benchmarks(config_path: str) -> Dict[str, Any]:
11161

11262
def run():
11363
parser = argparse.ArgumentParser()
64+
parser.add_argument("--name", default="nightly", help="Benchmark name.")
11465
parser.add_argument(
11566
"--ci", action="store_true", help="Running in GitHub Actions CI mode."
11667
)
68+
parser.add_argument(
69+
"--log-scuba", action="store_true", help="Upload results to Scuba."
70+
)
11771
args = parser.parse_args()
11872
setup_tritonbench_cwd()
11973
from tritonbench.utils.run_utils import run_in_task, setup_output_dir
74+
from tritonbench.utils.scuba_utils import decorate_benchmark_data, log_benchmark
12075

12176
run_timestamp, output_dir = setup_output_dir("nightly")
12277
# Run each operator
@@ -127,10 +82,32 @@ def run():
12782
output_file = output_dir.joinpath(f"{op_bench}.json")
12883
op_args.extend(["--output-json", str(output_file.absolute())])
12984
run_in_task(op=op_name, op_args=op_args, benchmark_name=op_bench)
85+
# write pass or fail to result json
86+
# todo: check every input shape has passed
87+
output_file_name = Path(output_file).stem
88+
if not os.path.exists(output_file) or os.path.getsize(output_file) == 0:
89+
logger.warning(f"[nightly] Failed to run {output_file_name}.")
90+
with open(output_file, "w") as f:
91+
json.dump({f"tritonbench_{output_file_name}-pass": 0}, f)
92+
else:
93+
with open(output_file, "r") as f:
94+
obj = json.load(f)
95+
obj[f"tritonbench_{output_file_name}-pass"] = 1
96+
with open(output_file, "w") as f:
97+
json.dump(obj, f, indent=4)
13098
output_files.append(output_file)
13199
# Reduce all operator CSV outputs to a single output json
132-
result_json_file = reduce(run_timestamp, output_dir, output_files, args)
100+
benchmark_data = [json.load(open(f, "r")) for f in output_files]
101+
aggregated_obj = decorate_benchmark_data(
102+
args.name, run_timestamp, args.ci, benchmark_data
103+
)
104+
result_json_file = os.path.join(output_dir, "result.json")
105+
with open(result_json_file, "w") as fp:
106+
json.dump(aggregated_obj, fp, indent=4)
133107
logger.info(f"[nightly] logging result json file to {result_json_file}.")
108+
if args.log_scuba:
109+
log_benchmark(aggregated_obj)
110+
logger.info(f"[nightly] logging results to scuba.")
134111

135112

136113
if __name__ == "__main__":

run.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
import argparse
99
import os
10-
import shlex
1110
import sys
12-
from typing import List, Tuple
11+
import time
12+
from datetime import datetime
13+
from typing import List
1314

1415
from tritonbench.operator_loader import get_op_loader_bench_cls_by_name, is_loader_op
1516

@@ -35,6 +36,7 @@
3536

3637

3738
def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorResult:
39+
run_timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d%H%M%S")
3840
if is_loader_op(args.op):
3941
Opbench = get_op_loader_bench_cls_by_name(args.op)
4042
else:
@@ -72,6 +74,13 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorRe
7274
if "triton_type" in args:
7375
kwargs["triton_type"] = args.triton_type
7476
log_benchmark(**kwargs)
77+
# Log benchmark output to scuba even if not in fbcode
78+
if args.log_scuba and not is_fbcode():
79+
from tritonbench.utils.scuba_utils import log_benchmark
80+
81+
log_benchmark(
82+
benchmark_data=None, run_timestamp=run_timestamp, opbench=opbench
83+
)
7584

7685
if args.plot:
7786
try:

tritonbench/utils/git_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import subprocess
33
from datetime import datetime
4-
from typing import Optional
54

65

76
def get_branch(repo: str, commit: str) -> str:
@@ -11,6 +10,8 @@ def get_branch(repo: str, commit: str) -> str:
1110
If a commit does not belong to any branch, return "unknown"
1211
If a commit belongs to many branches, return the very first branch.
1312
"""
13+
if repo == "unknown":
14+
return "unknown"
1415
assert os.path.exists(repo), f"{repo} path does not exist."
1516
cmd = ["git", "branch", "-a", "--contains", commit, "--no-color"]
1617
branch_names = subprocess.check_output(cmd, cwd=repo).decode().strip().splitlines()
@@ -27,6 +28,8 @@ def get_commit_time(repo: str, commit: str) -> str:
2728
commit: hash of a commit
2829
If a commit does not exist, return "unknown"
2930
"""
31+
if repo == "unknown":
32+
return "unknown"
3033
assert os.path.exists(repo), f"{repo} path does not exist."
3134
git_date_cmd = ["git", "show", "--no-patch", "--format=%ci", commit]
3235
git_date = subprocess.check_output(git_date_cmd, cwd=repo).decode().strip()
@@ -40,6 +43,8 @@ def get_commit_time(repo: str, commit: str) -> str:
4043
def get_current_hash(repo: str) -> str:
4144
"""Get the HEAD hash of a git repo.
4245
repo: local git repo path"""
46+
if repo == "unknown":
47+
return "unknown"
4348
cmd = ["git", "rev-parse", "--verify", "HEAD"]
4449
output = subprocess.check_output(cmd, cwd=repo).decode().strip()
4550
return output

tritonbench/utils/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,9 @@ def get_parser(args=None):
288288
help="Configuration B for A/B testing. Specify operator-specific arguments as a string. "
289289
"Example: '--side-b \"--dynamic\"'",
290290
)
291+
parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
291292

292293
if is_fbcode():
293-
parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
294294
parser.add_argument(
295295
"--production-shapes",
296296
action="store_true",

tritonbench/utils/run_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def get_run_env(
5757
run_env["pytorch_commit"] = torch.version.git_version
5858
# we assume Tritonbench CI will properly set Triton commit hash in env
5959
run_env["triton_commit"] = os.environ.get(
60-
"TRITONBENCH_TRITON_MAIN_COMMIT", "unknown"
60+
"TRITONBENCH_TRITON_COMMIT_HASH", get_current_hash(repo_locs["triton"])
6161
)
62-
run_env["tritonbench_commit"] = get_current_hash(REPO_PATH)
62+
run_env["tritonbench_commit"] = get_current_hash(repo_locs["tritonbench"])
6363
for repo in ["triton", "pytorch", "tritonbench"]:
6464
repo_loc = repo_locs.get(repo, None)
6565
if not run_env[f"{repo}_commit"] == "unknown" and repo_loc:

0 commit comments

Comments (0)