Skip to content

Commit ae51ff5

Browse files
authored
[power] Add initial power chart support (meta-pytorch#420)
1 parent 3e53432 commit ae51ff5

File tree

9 files changed

+155
-2
lines changed

9 files changed

+155
-2
lines changed

.github/workflows/_linux-test-h100.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ jobs:
3030
- name: Install Tritonbench
3131
run: |
3232
# speedup install and skip compile by reusing the docker .so files
33+
. "${SETUP_SCRIPT}"
3334
mkdir -p /workspace/tritonbench/.data
3435
ln -s /workspace/tritonbench/.data .
36+
pip install -r requirements.txt
3537
- name: Test Tritonbench operators on H100 GPU
3638
run: |
3739
bash ./.ci/tritonbench/test-gpu.sh

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
packaging
2-
pynvml
2+
nvidia-ml-py
33
psutil
44
tabulate
5+
matplotlib
56
transformers==4.46.1

test/test_gpu/skip_tests_h100_pytorch.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ fp8_attention:
2626
# fp8_fused_quant_gemm_rowwise requires fb-only kernels
2727
fp8_fused_quant_gemm_rowwise:
2828
gemm:
29+
# torch._inductor.exc.InductorError: LoweringException
30+
# NoValidChoicesError: No choices to select.
31+
- pt2_cutlass_matmul
2932
# internal only kernels
3033
- hstu_triton_matmul
3134
# jagged tests are slow, so disable them in OSS

test/test_gpu/skip_tests_h100_triton_main.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ gemm:
3030
- hstu_triton_matmul
3131
# No need to test cutlass on triton main
3232
- pt2_cutlass_matmul
33+
int4_gemm:
3334
# jagged tests are slow, so disable them in OSS
3435
jagged_layer_norm:
3536
jagged_mean:
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .chart import power_chart_begin, power_chart_end
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import csv
2+
import logging
3+
import os
4+
import signal
5+
import subprocess
6+
import time
7+
8+
import matplotlib.pyplot as plt
9+
import torch
10+
11+
# query every 10 ms
12+
QUERY_FREQUENCY = 10
13+
QUERY_STDOUT_FILE = "power.csv"
14+
QUERY_STDERR_FILE = "power.log"
15+
QUERY_COMMAND = """nvidia-smi -lms {QUERY_FREQUENCY} -i {QUERY_DEVICE} --query-gpu=power.draw.average,power.draw.instant,power.max_limit,temperature.gpu,temperature.memory,clocks.current.sm,clocks.current.memory,clocks_throttle_reasons.hw_thermal_slowdown,clocks_throttle_reasons.sw_thermal_slowdown --format=csv,nounits"""
16+
global QUERY_PROC
17+
global POWER_OUTPUT_DIR
18+
19+
QUERY_PROC = None
20+
POWER_OUTPUT_DIR = None
21+
22+
logger = logging.getLogger(__name__)
23+
logger.setLevel(logging.INFO)
24+
25+
26+
def _get_cuda_device_id():
27+
return torch.cuda.current_device()
28+
29+
30+
def _gen_power_charts(benchmark_name: str, device_name: str, power_csv_file: str):
31+
# Read CSV
32+
with open(power_csv_file) as f:
33+
reader = csv.reader(f)
34+
header = next(reader) # first row as header
35+
header = [col.strip() for col in header]
36+
data = {col: [] for col in header}
37+
38+
for row in reader:
39+
for col, value in zip(header, row):
40+
if value == "[N/A]":
41+
logger.warning(
42+
f"[tritonbench][power] {col} is not available, skipping"
43+
)
44+
value = 0.0
45+
else:
46+
value = (
47+
float(value)
48+
if col
49+
not in [
50+
"clocks_event_reasons.hw_thermal_slowdown",
51+
"clocks_event_reasons.sw_thermal_slowdown",
52+
]
53+
else value
54+
)
55+
data[col].append(value)
56+
57+
# Generate synthetic time axis (100 ms per sample)
58+
n_samples = len(next(iter(data.values())))
59+
time = [i * 0.1 for i in range(n_samples)] # seconds (0.1s = 100 ms)
60+
61+
# Plot power chart
62+
plt.figure(figsize=(10, 6))
63+
for power_col in header[:3]:
64+
plt.plot(time, data[power_col], label=power_col)
65+
plt.xlabel("Time (s)")
66+
plt.ylabel("Power (W)")
67+
plt.legend()
68+
plt.title(
69+
f"[tritonbench] {benchmark_name} power consumption over time on {device_name}"
70+
)
71+
plt.savefig(
72+
os.path.join(POWER_OUTPUT_DIR, "power.png"), dpi=300, bbox_inches="tight"
73+
)
74+
# Plot temp chart
75+
plt.figure(figsize=(10, 6))
76+
for temp_col in header[3:5]:
77+
plt.plot(time, data[temp_col], label=temp_col)
78+
plt.xlabel("Time (s)")
79+
plt.ylabel("Temperature (C)")
80+
plt.legend()
81+
plt.title(f"[tritonbench] {benchmark_name} temperature over time on {device_name}")
82+
plt.savefig(
83+
os.path.join(POWER_OUTPUT_DIR, "temp.png"), dpi=300, bbox_inches="tight"
84+
)
85+
# Plot frequency chart
86+
plt.figure(figsize=(10, 6))
87+
for temp_col in header[5:7]:
88+
plt.plot(time, data[temp_col], label=temp_col)
89+
plt.xlabel("Time (s)")
90+
plt.ylabel("Frequency (MHz)")
91+
plt.legend()
92+
plt.title(f"[tritonbench] {benchmark_name} frequency over time on {device_name}")
93+
plt.savefig(
94+
os.path.join(POWER_OUTPUT_DIR, "freq.png"), dpi=300, bbox_inches="tight"
95+
)
96+
97+
98+
def power_chart_begin(benchmark_name, output_dir):
99+
# check no other proc is running
100+
global QUERY_PROC, POWER_OUTPUT_DIR
101+
assert QUERY_PROC is None, "Power query process must be None to start a new one"
102+
# clean up the directory
103+
POWER_OUTPUT_DIR = os.path.join(output_dir, benchmark_name)
104+
if not os.path.exists(POWER_OUTPUT_DIR):
105+
os.mkdir(POWER_OUTPUT_DIR)
106+
stdout_file_path = os.path.join(POWER_OUTPUT_DIR, QUERY_STDOUT_FILE)
107+
stderr_file_path = os.path.join(POWER_OUTPUT_DIR, QUERY_STDERR_FILE)
108+
# Run the command
109+
query_cmd = QUERY_COMMAND.format(
110+
QUERY_FREQUENCY=QUERY_FREQUENCY, QUERY_DEVICE=_get_cuda_device_id()
111+
).split(" ")
112+
with open(stdout_file_path, "w") as stdout_file, open(
113+
stderr_file_path, "w"
114+
) as stderr_file:
115+
QUERY_PROC = subprocess.Popen(
116+
query_cmd, stdout=stdout_file, stderr=stderr_file, start_new_session=True
117+
)
118+
119+
120+
def power_chart_end():
121+
global QUERY_PROC, POWER_OUTPUT_DIR
122+
assert QUERY_PROC is not None, "Power query process cannot be None"
123+
# Kill the process
124+
QUERY_PROC.send_signal(signal.SIGINT)
125+
time.sleep(0.2)
126+
assert (
127+
QUERY_PROC.poll() is not None
128+
), "Power query process must be killed to proceed"
129+
# generate the chart based on csv
130+
stdout_file_path = os.path.join(POWER_OUTPUT_DIR, QUERY_STDOUT_FILE)
131+
benchmark_name = os.path.basename(POWER_OUTPUT_DIR)
132+
device_name = torch.cuda.get_device_name(_get_cuda_device_id())
133+
_gen_power_charts(benchmark_name, device_name, stdout_file_path)
134+
logger.warning(f"[tritonbench][power] Power chart saved to {POWER_OUTPUT_DIR}.")

tritonbench/operators/softmax/operator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def parse_op_args(args: List[str]):
4242

4343

4444
class Operator(BenchmarkOperator):
45+
DEFAULT_PRECISION = "fp16"
4546
is_compute_bound = False
4647

4748
def __init__(

tritonbench/utils/parser.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,12 @@ def get_parser(args=None):
180180
default=None,
181181
help="Dump Triton IR to specific directory.",
182182
)
183+
parser.add_argument(
184+
"--power-chart",
185+
type=str,
186+
default=None,
187+
help="Dump GPU power chart to specific directory.",
188+
)
183189
parser.add_argument(
184190
"--gpu-lockdown",
185191
action="store_true",

tritonbench/utils/triton_op.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828
import tabulate
2929
import torch
3030
import triton
31-
from torch.utils._pytree import tree_flatten, tree_map
3231

3332
from tritonbench.components.do_bench import do_bench_wrapper, Latency
3433
from tritonbench.components.export import export_data
3534

35+
from tritonbench.components.power.chart import power_chart_begin, power_chart_end
3636
from tritonbench.utils.constants import (
3737
DEFAULT_QUANTILES,
3838
DEFAULT_REP,
@@ -873,6 +873,8 @@ def run(
873873
) -> None:
874874
"""Benchmarking the operator and returning its metrics."""
875875
metrics = []
876+
if self.tb_args.power_chart:
877+
power_chart_begin(self.benchmark_name, self.tb_args.power_chart)
876878
try:
877879
if "proton" in self.required_metrics:
878880
import triton.profiler as proton
@@ -998,6 +1000,8 @@ def _reduce_benchmarks(acc, bm_name: str):
9981000
os._exit(1)
9991001
raise
10001002
finally:
1003+
if self.tb_args.power_chart:
1004+
power_chart_end()
10011005
self.output = BenchmarkOperatorResult(
10021006
benchmark_name=self.tb_args.benchmark_name,
10031007
op_name=self.name,

0 commit comments

Comments
 (0)