
Commit 139ce65

[triton_kernels] refactor roofline plotting code (triton-lang#7915)
1 parent 2fafd63 commit 139ce65

4 files changed: +278 −159 lines
bench_mlp.py: 45 additions & 156 deletions
@@ -1,98 +1,36 @@
+from itertools import chain
 from pathlib import Path
 from copy import deepcopy
-import matplotlib.pyplot as plt
 import triton.profiler as proton
-from triton.profiler import viewer
 import torch
 import argparse
 import triton_kernels
 import triton_kernels.swiglu
 from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.target_info import is_hip, get_cdna_version
-from dataclasses import dataclass
+from triton_kernels.target_info import get_cdna_version
 import distributed as triton_dist
 from triton_kernels.tensor_details import layout
 from bench_utils import quantize_weight
+import tempfile
+import roofline
 
-if torch.cuda.is_available() and not is_hip():
-    from triton._C.libtriton import nvidia
 
-    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
-    cublas = nvidia.cublas.CublasLt(cublas_workspace)
-else:
-    cublas = None
-
-
-@dataclass
-class PerfData:
-    time: float
-    flops: float
-    bytes: float
-    bitwidth: int
-    device_type: str
-    device_info: dict
-
-    @property
-    def tflops(self):
-        return self.flops / self.time * 1e-3
-
-    @property
-    def tbps(self):
-        return self.bytes / self.time * 1e-3
-
-    @property
-    def opint(self):
-        # operational intensity
-        assert self.bytes > 0
-        return self.flops / self.bytes
-
-    @property
-    def max_tbps(self):
-        return (proton.specs.max_bps(
-            self.device_type,
-            self.device_info["arch"],
-            self.device_info["bus_width"],
-            self.device_info["memory_clock_rate"],
-        ) * 1e-12)
-
-    @property
-    def max_tflops(self):
-        return (proton.specs.max_flops(
-            self.device_type,
-            self.device_info["arch"],
-            self.bitwidth,
-            self.device_info["num_sms"],
-            self.device_info["clock_rate"],
-        ) * 1e-12)
-
-    @property
-    def util(self) -> float:
-        assert self.bitwidth in (8, 16)
-        min_t_flop = self.flops / self.max_tflops * 1e-3
-        min_t_bw = self.bytes / self.max_tbps * 1e-3
-        return max(min_t_flop, min_t_bw) / self.time
-
-
-def get_bench_path(name, rank, x_dtype, w_dtype, TP, EP):
-    return Path(f"logs/{name}/{rank}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/")
-
-
-def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name):
+def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP):
     assert n_expts_tot % EP == 0
     assert dim2 % TP == 0
     rank, world_size = triton_dist.setup()
     dev = f"cuda:{rank}"
     DP = world_size
+    batch = batch_per_expt * n_expts_tot // n_expts_act
 
     assert n_expts_tot % EP == 0, f"{n_expts_tot=}, {EP=}, n_expts_tot must be divisible by EP"
     assert dim2 % TP == 0, f"{dim2=}, {TP=}, dim2 must be divisible by TP"
 
-    # input
+    # -- init data --
     # weights
     wg = triton_dist.broadcast(torch.randn((dim1, n_expts_tot), device=dev))
     w1 = torch.randn((n_expts_tot // EP, dim1, dim2 // TP), device=dev)
     w2 = torch.randn((n_expts_tot // EP, dim2 // TP // 2, dim1), device=dev)
-
     # biases
     bg = triton_dist.broadcast(torch.randn((n_expts_tot, ), device=dev))
     b1 = torch.randn((n_expts_tot // EP, dim2 // TP), device=dev)
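For readers tracking what the refactor removes: the deleted PerfData properties are the standard roofline quantities, now owned by the shared roofline module. A minimal self-contained sketch of the same math (times in nanoseconds, max_tflops/max_tbps as the compute and bandwidth roofs; the function names below are illustrative, not the new module's API):

def opint(flops: float, num_bytes: float) -> float:
    # operational intensity [FLOP/byte]
    assert num_bytes > 0
    return flops / num_bytes

def util(time_ns: float, flops: float, num_bytes: float,
         max_tflops: float, max_tbps: float) -> float:
    # a kernel can finish no faster than the larger of its compute-limited
    # and bandwidth-limited minimum runtimes; util = achievable / achieved
    min_t_flop_ns = flops / max_tflops * 1e-3  # FLOP / (TFLOP/s) -> ns
    min_t_bw_ns = num_bytes / max_tbps * 1e-3  # byte / (TB/s)    -> ns
    return max(min_t_flop_ns, min_t_bw_ns) / time_ns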
@@ -125,16 +63,15 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
     pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale)
 
     # -- benchmark --
-    fpath = get_bench_path(name, rank, x_dtype, w_dtype, TP, EP) / f"profiles/batch-{batch}.hatchet"
-    fpath.parent.mkdir(parents=True, exist_ok=True)
     x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
     # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
     if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
         x_dtype = torch.float8_e4m3fnuz
 
     input_x = torch.randn((batch // DP, dim1), device=dev)
     # run layer
-    proton.start(str(fpath.with_suffix("")), hook="triton")
+    fpath = Path(tempfile.mktemp())
+    proton.start(str(fpath), hook="triton")
     input_x = input_x.to(x_dtype)
     xg = input_x.to(wg.dtype if n_expts_tot > 1 else input_x.dtype)
     for i in range(100):
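The hunk above is the profiling half of the refactor: bench_mlp no longer manages per-run log directories; it profiles into a throwaway path and defers analysis to roofline.parse_profile (next hunk). Condensed, the new pattern is as follows (run_layer is a hypothetical stand-in for the measured 100-iteration loop; proton appends the .hatchet suffix to the session path, which is why the old code stripped it with with_suffix("") and the new code adds it back when parsing):

import tempfile
from pathlib import Path
import triton.profiler as proton
import roofline  # the helper module this PR factors the analysis into

fpath = Path(tempfile.mktemp())
proton.start(str(fpath), hook="triton")
run_layer()  # hypothetical: the kernels being measured
proton.finalize()
perf = roofline.parse_profile(fpath.with_suffix(".hatchet"), useful_op_regex=".*matmul.*")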
@@ -151,114 +88,66 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
                        precision_config=pc2)
         x = triton_dist.reduce_scatter(x, metadata=metadata, dim=0)
     proton.finalize()
-
-    # -- analyze --
-    gf, _, _, info = viewer.read(fpath)
-    # Now the dataframe only contains leave nodes (i.e., kernels) that perform matmuls
-    matmuls = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*matmul.*' AND c IS LEAF").dataframe
-    bytes = matmuls["bytes"].sum()
-    flops = sum(matmuls[[c for c in ["flops8", "flops16"] if c in matmuls.columns]].sum())
-    # Compute total time (incl. "not useful" work)
-    time = gf.filter("MATCH ('*', c) WHERE c IS LEAF").dataframe["time (ns)"].sum()
-    device_type = matmuls["device_type"].iloc[0]
-    device_id = matmuls["device_id"].iloc[0]
-    device_info = info[device_type][device_id]
-    return PerfData(
-        time=time,
-        flops=flops,
-        bytes=bytes,
-        bitwidth=x.dtype.itemsize * 8,
-        device_type=device_type,
-        device_info=device_info,
-    )
+    return roofline.parse_profile(fpath.with_suffix(".hatchet"), useful_op_regex=".*matmul.*")
 
 
-def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="",
-                 verbose=True):
-    from itertools import chain
-    from bisect import bisect_left
+def roofline_mlp(batch_sizes, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, \
+                 name="", verbose=True):
+    out_path = Path(f"logs/{name}/{x_dtype}x-{w_dtype}w-TP{TP}-EP{EP}/")
+    out_path.mkdir(parents=True, exist_ok=True)
+    csv_path = roofline.compute_roofline(dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP,  # fixed args
+                                         bench_fn=bench_mlp,  # function to benchmark
+                                         intensity_proxy_name="batch_per_expt",  # intensity proxy name
+                                         intensity_proxy_values=batch_sizes,  # intensity proxy values to sweep
+                                         verbose=verbose,  # options
+                                         out_path=out_path.with_suffix(".csv"))  # output path
+    png_path = roofline.plot_roofline(series=[csv_path],  # roofline data to plot
+                                      flops_dtype=x_dtype,  # dtype to use for FLOPS roof
+                                      xlabel="batch_per_expt", title=out_path,  # plot options
+                                      out_path=out_path.with_suffix(".png"),  # output path
+                                      max_tbps="memset", max_tflops="cublas")  # hardware limits
 
-    batches = list(chain(*[range(*r) for r in batch_ranges]))
-    # collect performance data
-    perfs = []
-    bench_case = f"{name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})"
-    print(f"Benchmarking {bench_case}...")
-    print("===============================================================")
-    for batch in batches:
-        perfs += [bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name)]
-        if verbose:
-            print(f"Batch: {batch}; Util: {perfs[-1].util}; TFLOPS: {perfs[-1].tflops}; TBPS: {perfs[-1].tbps}")
-    print("===============================================================")
-    # machine limits
-    max_tbps = perfs[0].max_tbps
-    max_tflops = perfs[0].max_tflops
-    fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
-    ax.set_xlabel("batch size (toks/expt)")
-    ax.set_ylabel("performance [TFLOP/s]")
-    ax.set_title(f"{bench_case} roofline")
-    # add a tiny margin so points are not flush with the frame
-    xs = [batch * n_expts_act / n_expts_tot for batch in batches]
-    perf = [p.tflops for p in perfs]
-    xmin, xmax = min(xs), max(xs)
-    dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
-    ax.set_xlim(xmin - dx, xmax + dx)
-    ax.set_ylim(100, max_tflops + 500)
-    # plot roofline
-    opints = [p.opint for p in perfs]
-    knee = bisect_left(opints, max_tflops / max_tbps)
-    if knee > 0:  # has a bandwidth-bound knee
-        x_bw = [xs[0], xs[knee - 1]]
-        y_bw = [opints[0] * max_tbps, max_tflops]
-    else:  # no knee found, compute-bound only
-        x_bw = y_bw = []
-    x_comp = xs[knee:]
-    y_comp = [max_tflops] * len(x_comp)
-    ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.1f} TB/s)", color="blue")
-    ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)", color="orange")
-    # plot data
-    ax.scatter(xs, perf, marker="+")
-    ax.legend(frameon=False, loc="lower right")
-    ax.grid(True, which="both", ls=":", lw=0.5)
-    fig.tight_layout()
-    rank, _ = triton_dist.setup()
-    fpath = get_bench_path(name, rank, x_dtype, w_dtype, TP, EP) / "roofline.png"
-    plt.savefig(fpath)
+    return png_path
 
 
 if __name__ == "__main__":
     has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
-    batch_ranges_dense = [(1024, 32768, 1024)]
-    batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
+    batch_sizes_dense = [(128, 8192, 128)]
+    batch_ranges_moe = [(2**(2 + k), 2**(3 + k), min(2**k, 32)) for k in range(8)]
+    batch_sizes_moe = list(chain(*[range(*r) for r in batch_ranges_moe]))
     dense_dtypes = ["fp8", "fp8"]
     quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
     rank, world_size = triton_dist.setup()
     if world_size > 1:
         # Running all workloads at once may cause OOM on some GPUs such as H100 80GB.
         # Thus we request users to run each workload separately.
         # For example, all eligible combinations of options are listed below when four GPUs are used:
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name llama4-maverick
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name llama4-maverick
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name llama4-maverick
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name gpt-oss-x2
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name gpt-oss-x2
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name gpt-oss-x2
         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name dense
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name llama4-maverick --quantized
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name llama4-maverick --quantized
-        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name llama4-maverick --quantized
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name gpt-oss-x2 --quantized
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name gpt-oss-x2 --quantized
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name gpt-oss-x2 --quantized
         # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name dense --quantized
         argparse = argparse.ArgumentParser()
         argparse.add_argument("--tp", type=int, default=1)
         argparse.add_argument("--ep", type=int, default=1)
-        argparse.add_argument("--name", type=str, choices=["dense", "llama4-maverick"])
+        argparse.add_argument("--name", type=str, choices=["dense", "gpt-oss-x2"])
        argparse.add_argument("--quantized", action="store_true", default=False)
         args = argparse.parse_args()
         dtypes = dense_dtypes if args.quantized else quantized_dtypes
         if args.name == "dense":
             assert args.ep == 1, "EP must be 1 for dense"
-            roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dtypes, TP=args.tp, EP=args.ep, name="dense")
+            roofline_mlp(batch_sizes_dense, 8192, 8192, 1, 1, *dtypes, TP=args.tp, EP=args.ep, name="dense")
         else:
-            roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dtypes, TP=args.tp, EP=args.ep, name="llama4-maverick")
+            roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *dtypes, TP=args.tp, EP=args.ep, name="gpt-oss-x2")
         triton_dist.cleanup()
     else:
-        roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
-        roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
-        roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
-        roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
+        pass
+        # roofline_mlp(batch_sizes_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
+        # roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *dense_dtypes, TP=1, EP=1, name="gpt-oss-x2")
+        roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *quantized_dtypes, TP=1, EP=1, name="gpt-oss-x2")
+        # roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *quantized_dtypes, TP=2, EP=1, name="gpt-oss-x2")
+        # roofline_mlp(batch_sizes_moe, 5760, 5760, 128, 4, *quantized_dtypes, TP=4, EP=1, name="gpt-oss-x2")
+        # roofline_mlp(batch_ranges_moe, 5760, 5760, 128, 4, *quantized_dtypes, TP=8, EP=1, name="gpt-oss-x2")
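A behavioral note on the new driver: roofline_mlp now sweeps batch_per_expt (tokens per expert) as the operational-intensity proxy instead of raw batch sizes, and bench_mlp converts the proxy back to a global batch. This inverts the old plotting transform xs = batch * n_expts_act / n_expts_tot. A quick sanity check of the round trip with the MoE config used here (n_expts_tot=128, n_expts_act=4):

n_expts_tot, n_expts_act = 128, 4
for batch_per_expt in (4, 32, 256):
    batch = batch_per_expt * n_expts_tot // n_expts_act  # as in the new bench_mlp
    # each token activates n_expts_act of n_expts_tot experts, so tokens per
    # expert is batch * n_expts_act / n_expts_tot, recovering the proxy value
    assert batch * n_expts_act // n_expts_tot == batch_per_expt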
