Commit 5045b79

[KERNELS][MULTI-GPU] Initialize simple multi-gpu moe baseline (#7352)

1 parent d7e43ad

File tree: 8 files changed, +698 −80 lines

.github/workflows/integration-tests-amd.yml

Lines changed: 3 additions & 1 deletion

@@ -139,7 +139,9 @@ jobs:
           cd ../../triton_kernels/
           python3 -m pytest -s -n 12 tests/
         fi
-
+      - name: Run distributed tests
+        run: |
+          make test-distributed
       - name: Run asan tests on AMD
         if: false
         run: |

.github/workflows/integration-tests-nvidia.yml

Lines changed: 10 additions & 13 deletions

@@ -68,13 +68,16 @@ jobs:
      - name: Update PATH
        run: |
          echo "$HOME/.local/bin" >> $GITHUB_PATH
+     - name: Setup Python environment for GB200
+       if: ${{ matrix.runner[0] == 'nvidia-gb200' }}
+       run: |
+         echo "/venv/bin" >> $GITHUB_PATH
+         echo "VIRTUAL_ENV=/venv" >> $GITHUB_ENV
+         echo "PYTHONHOME=" >> $GITHUB_ENV
      - name: Install Triton
        env:
          CUDA_HOME: "/usr/local/cuda"
        run: |
-         if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
-           source /venv/bin/activate
-         fi
          nproc
          nvidia-smi
          echo "PATH is '$PATH'"
@@ -85,20 +88,14 @@ jobs:
      - name: Run lit tests
        run: make test-lit
      - name: Run python tests on CUDA
-       run: |
-         if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
-           source /venv/bin/activate
-         fi
-         make NUM_PROCS=24 test-unit
+       run: make NUM_PROCS=24 test-unit
+     - name: Run distributed tests
+       run: make test-distributed
      - name: Run interpreter tests
        if: ${{ matrix.runner[0] == 'nvidia-h100' }}
        run: make test-interpret
      - name: Run regression tests
-       run: |
-         if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
-           source /venv/bin/activate
-         fi
-         make test-regression
+       run: make test-regression
      - name: Run C++ unittests
        run: make test-cpp
      - name: Run Proton tests

Makefile

Lines changed: 6 additions & 0 deletions

@@ -44,6 +44,12 @@ test-unit: all
 	$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
 	$(PYTEST) -s -n $(NUM_PROCS) python/test/gluon
 
+.PHONY: test-distributed
+test-distributed: all
+	$(PYTHON) -m pip install --upgrade pip
+	$(PYTHON) -m pip install python/triton_kernels -v
+	$(PYTEST) -s python/triton_kernels/bench/distributed.py
+
 .PHONY: test-gluon
 test-gluon: all
 	$(PYTEST) -s -n $(NUM_PROCS) python/test/gluon
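The new target drives python/triton_kernels/bench/distributed.py, which bench_mlp.py (below) imports as triton_dist for setup, broadcast, all_gather, reduce_scatter, routing, and cleanup; the module itself is not part of this excerpt. As a rough sketch only, assuming the usual torch.distributed conventions under torchrun rather than the actual implementation, the setup/cleanup pair might look like:

import os

import torch
import torch.distributed as dist


def setup():
    # Hypothetical sketch: pick up the rendezvous info torchrun exports,
    # and fall back to a single-process run when launched without torchrun.
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size > 1 and not dist.is_initialized():
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    return rank, world_size


def cleanup():
    # Hypothetical counterpart: tear down the process group if one exists.
    if dist.is_initialized():
        dist.destroy_process_group()

Note that bench_mlp.py calls setup() from several places (bench_mlp, roofline_mlp, and the main block), so whatever the real module does has to be idempotent, as sketched here.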

python/triton/language/standard.py

Lines changed: 2 additions & 2 deletions

@@ -317,9 +317,9 @@ def _or_combine(x, y):
 
 @core._tensor_member_fn
 @jit
-@core._add_reduction_docstr("reduce_of")
+@core._add_reduction_docstr("reduce_or")
 def reduce_or(input, axis, keep_dims=False):
-    core.static_assert(input.type.scalar.is_int(), "reduce_of only supported for integers")
+    core.static_assert(input.type.scalar.is_int(), "reduce_or only supported for integers")
     return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)
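Both renamed strings were "reduce_of" typos; the reduction itself is unchanged. For context, a minimal kernel exercising it, assuming reduce_or is exported on triton.language like the other standard reductions (kernel and buffer names are illustrative, not from the commit):

import torch
import triton
import triton.language as tl


@triton.jit
def or_reduce_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # the static_assert fixed above requires an integer dtype here
    tl.store(out_ptr, tl.reduce_or(x, axis=0))


x = torch.randint(0, 1 << 16, (1024,), device="cuda", dtype=torch.int32)
out = torch.empty(1, device="cuda", dtype=torch.int32)
or_reduce_kernel[(1,)](x, out, BLOCK=1024)

ref = 0
for v in x.tolist():
    ref |= v  # bitwise-OR reduction on the host for comparison
assert out.item() == ref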
python/triton_kernels/bench/bench_mlp.py

Lines changed: 101 additions & 63 deletions

@@ -4,46 +4,25 @@
 import triton.profiler as proton
 from triton.profiler import viewer
 import torch
+import argparse
 import triton_kernels
 import triton_kernels.swiglu
-from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
 from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.numerics import InFlexData
-from triton_kernels.routing import routing
-from triton_kernels.target_info import is_cuda, is_hip, get_cdna_version, cuda_capability_geq
-from triton_kernels.tensor import convert_layout
-from triton_kernels.tensor import wrap_torch_tensor, FP4
+from triton_kernels.target_info import is_hip, get_cdna_version
 from dataclasses import dataclass
+import distributed as triton_dist
 from triton_kernels.tensor_details import layout
+from bench_utils import quantize_weight
 
 if torch.cuda.is_available() and not is_hip():
     from triton._C.libtriton import nvidia
+
     cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
     cublas = nvidia.cublas.CublasLt(cublas_workspace)
 else:
     cublas = None
 
 
-def quantize(w, dtype, **opt):
-    if dtype == "bf16":
-        wq = w.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2)
-        return wq, InFlexData(), None
-    elif dtype == "fp8":
-        fp8e4_dtype = torch.float8_e4m3fn if get_cdna_version() != 3 \
-            else torch.float8_e4m3fnuz
-        wq = w.to(fp8e4_dtype)
-        if is_cuda() and not cuda_capability_geq(10, 0):
-            wq = wq.transpose(-1, -2).contiguous().transpose(-1, -2)
-        return wq, InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), None
-    else:
-        assert dtype == "mx4", f"{dtype=}"
-        w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
-        if opt:
-            w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
-            w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
-        return w, InFlexData(), w_scale
-
-
 @dataclass
 class PerfData:
     time: float
@@ -69,13 +48,22 @@ def opint(self):
 
     @property
     def max_tbps(self):
-        return proton.specs.max_bps(self.device_type, self.device_info["arch"], self.device_info["bus_width"],
-                                    self.device_info["memory_clock_rate"]) * 1e-12
+        return (proton.specs.max_bps(
+            self.device_type,
+            self.device_info["arch"],
+            self.device_info["bus_width"],
+            self.device_info["memory_clock_rate"],
+        ) * 1e-12)
 
     @property
     def max_tflops(self):
-        return proton.specs.max_flops(self.device_type, self.device_info["arch"], self.bitwidth,
-                                      self.device_info["num_sms"], self.device_info["clock_rate"]) * 1e-12
+        return (proton.specs.max_flops(
+            self.device_type,
+            self.device_info["arch"],
+            self.bitwidth,
+            self.device_info["num_sms"],
+            self.device_info["clock_rate"],
+        ) * 1e-12)
 
     @property
     def util(self) -> float:
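These two properties are pure reformatting, but they feed the util property whose body opens the next hunk: utilization is the best-case time implied by the rooflines divided by the measured time. Assuming min_t_flop and min_t_bw are the usual work-over-roof ratios (their definitions fall outside this diff), the arithmetic works out as follows, with invented numbers:

# Invented example values; only the formula mirrors PerfData.
max_tflops = 989.0   # compute roof, TFLOP/s
max_tbps = 3.35      # memory roof, TB/s
flops = 1.0e12       # measured floating-point work
bytes_moved = 1.0e9  # measured bytes moved
time = 1.5e-3        # measured seconds

min_t_flop = flops / (max_tflops * 1e12)    # best case if compute-bound, ~1.01 ms
min_t_bw = bytes_moved / (max_tbps * 1e12)  # best case if bandwidth-bound, ~0.30 ms
util = max(min_t_flop, min_t_bw) / time     # ~0.67, i.e. 67% of the roofline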
@@ -85,62 +73,83 @@ def util(self) -> float:
         return max(min_t_flop, min_t_bw) / self.time
 
 
+def get_bench_path(name, rank, x_dtype, w_dtype, TP, EP):
+    return Path(f"logs/{name}/{rank}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/")
+
+
 def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name):
     assert n_expts_tot % EP == 0
     assert dim2 % TP == 0
-    dev = "cuda"
+    rank, world_size = triton_dist.setup()
+    dev = f"cuda:{rank}"
+    DP = world_size
+
+    assert n_expts_tot % EP == 0, f"{n_expts_tot=}, {EP=}, n_expts_tot must be divisible by EP"
+    assert dim2 % TP == 0, f"{dim2=}, {TP=}, dim2 must be divisible by TP"
 
     # input
     # weights
-    wg = torch.randn((dim1, n_expts_tot), device=dev)
+    wg = triton_dist.broadcast(torch.randn((dim1, n_expts_tot), device=dev))
     w1 = torch.randn((n_expts_tot // EP, dim1, dim2 // TP), device=dev)
     w2 = torch.randn((n_expts_tot // EP, dim2 // TP // 2, dim1), device=dev)
+
     # biases
-    bg = torch.randn((n_expts_tot, ), device=dev)
+    bg = triton_dist.broadcast(torch.randn((n_expts_tot, ), device=dev))
     b1 = torch.randn((n_expts_tot // EP, dim2 // TP), device=dev)
     b2 = torch.randn((n_expts_tot // EP, dim1), device=dev)
+    ep_indx = (rank // TP) % EP
+    groups = [list(range(ep * TP, (ep + 1) * TP)) for ep in range(EP)]
+    b2 = triton_dist.broadcast(b2, src=ep_indx * TP, groups=groups, group_idx=ep_indx)
 
     # -- numerics --
-    optg = dict()
     opt1 = dict()
     opt2 = dict()
     if w_dtype == "mx4" and not is_hip():
         num_warps = 4 if batch <= 512 else 8
         value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
         scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
             mx_axis=1, num_warps=num_warps)
-        opt1 = {"value_layout": value_layout, "value_layout_opts": value_layout_opts, \
-                "scale_layout": scale_layout, "scale_layout_opts": scale_layout_opts}
+        opt1 = {
+            "value_layout": value_layout,
+            "value_layout_opts": value_layout_opts,
+            "scale_layout": scale_layout,
+            "scale_layout_opts": scale_layout_opts,
+        }
         opt2 = deepcopy(opt1)
-    wg, wg_flex, wg_scale = quantize(wg, "bf16", **optg)
-    w1, w1_flex, w1_scale = quantize(w1, w_dtype, **opt1)
-    w2, w2_flex, w2_scale = quantize(w2, w_dtype, **opt2)
+    wg, wg_flex, wg_scale = quantize_weight(wg, "bf16")
+    w1, w1_flex, w1_scale = quantize_weight(w1, w_dtype, **opt1)
+    w2, w2_flex, w2_scale = quantize_weight(w2, w_dtype, **opt2)
     pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale)
     act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.0, 1.0), 2)
     pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale)
     pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale)
 
     # -- benchmark --
-    fpath = Path(f"logs/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/profiles/batch-{batch}.hatchet")
+    fpath = get_bench_path(name, rank, x_dtype, w_dtype, TP, EP) / f"profiles/batch-{batch}.hatchet"
     fpath.parent.mkdir(parents=True, exist_ok=True)
     x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
     # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
     if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
         x_dtype = torch.float8_e4m3fnuz
 
-    x = torch.randn((batch, dim1), device=dev)
-    xg = x.to(wg.dtype if n_expts_tot > 1 else x_dtype)
-    x = x.to(x_dtype)
+    input_x = torch.randn((batch // DP, dim1), device=dev)
     # run layer
-    proton.start(str(fpath.with_suffix('')), hook="triton")
+    proton.start(str(fpath.with_suffix("")), hook="triton")
+    input_x = input_x.to(x_dtype)
+    xg = input_x.to(wg.dtype if n_expts_tot > 1 else input_x.dtype)
     for i in range(100):
-        if n_expts_tot > 1:
+        if n_expts_tot > 1:  # sparse
             logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
-            rdata, gather_indx, scatter_indx = routing(logits, n_expts_act, simulated_ep=EP)
-        else:
-            rdata, gather_indx, scatter_indx = None, None, None
-        x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
-        x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2)
+            x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(input_x, logits, n_expts_act,
+                                                                                EP=EP, TP=TP)
+        else:  # dense
+            x = triton_dist.all_gather(input_x, dim=0)
+            rdata, gather_indx, scatter_indx, metadata = None, None, None, None
+        if x.nelement() > 0:
+            x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
+            x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx,
+                           precision_config=pc2)
+        x = triton_dist.reduce_scatter(x, metadata=metadata, dim=0)
     proton.finalize()
 
     # -- analyze --
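One detail worth unpacking from this hunk is the broadcast-group bookkeeping for b2: ranks appear to be laid out TP-major, so each block of TP consecutive ranks forms one expert-parallel shard whose first rank sources the broadcast. A standalone replay of the indexing for world_size = TP * EP = 4 (plain Python, no torch needed):

TP, EP = 2, 2
groups = [list(range(ep * TP, (ep + 1) * TP)) for ep in range(EP)]
print(groups)  # [[0, 1], [2, 3]]

for rank in range(TP * EP):
    ep_indx = (rank // TP) % EP  # expert shard owning this rank
    src = ep_indx * TP           # broadcast source: first rank in the shard
    print(rank, ep_indx, src)    # ranks 0,1 -> shard 0, src 0; ranks 2,3 -> shard 1, src 2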
@@ -153,14 +162,21 @@ def bench_mlp
     device_type = matmuls["device_type"].iloc[0]
     device_id = matmuls["device_id"].iloc[0]
     device_info = info[device_type][device_id]
-    return PerfData(time=time, flops=flops, bytes=bytes, bitwidth=x.dtype.itemsize * 8, device_type=device_type,
-                    device_info=device_info)
+    return PerfData(
+        time=time,
+        flops=flops,
+        bytes=bytes,
+        bitwidth=x.dtype.itemsize * 8,
+        device_type=device_type,
+        device_info=device_info,
+    )
 
 
 def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="",
                  verbose=True):
     from itertools import chain
     from bisect import bisect_left
+
     batches = list(chain(*[range(*r) for r in batch_ranges]))
     # collect performance data
     perfs = []
@@ -198,18 +214,13 @@
     y_comp = [max_tflops] * len(x_comp)
     ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.1f} TB/s)", color="blue")
     ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)", color="orange")
-    x_bw, x_comp = xs[:knee], xs[knee:]
-    x_bw = [x_bw[0], x_comp[0]]
-    y_bw = [opints[0] * max_tbps, max_tflops]
-    y_comp = [max_tflops] * len(x_comp)
-    ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.1f} TB/s)")
-    ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)")
     # plot data
     ax.scatter(xs, perf, marker="+")
     ax.legend(frameon=False, loc="lower right")
     ax.grid(True, which="both", ls=":", lw=0.5)
     fig.tight_layout()
-    fpath = Path(f"logs/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/roofline.png")
+    rank, _ = triton_dist.setup()
+    fpath = get_bench_path(name, rank, x_dtype, w_dtype, TP, EP) / "roofline.png"
     plt.savefig(fpath)
 
@@ -219,7 +230,34 @@
     batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
     dense_dtypes = ["fp8", "fp8"]
     quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
-    roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
-    roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
-    roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
-    roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
+    rank, world_size = triton_dist.setup()
+    if world_size > 1:
+        # Running all workloads at once may cause OOM on some GPUs such as H100 80GB,
+        # so users are asked to run each workload separately.
+        # For example, all eligible combinations of options when four GPUs are used:
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name llama4-maverick
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name llama4-maverick
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name llama4-maverick
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name dense
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 2 --ep 2 --name llama4-maverick --quantized
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 1 --ep 4 --name llama4-maverick --quantized
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name llama4-maverick --quantized
+        # torchrun --nproc-per-node=4 ./bench_mlp.py --tp 4 --ep 1 --name dense --quantized
+        parser = argparse.ArgumentParser()
+        parser.add_argument("--tp", type=int, default=1)
+        parser.add_argument("--ep", type=int, default=1)
+        parser.add_argument("--name", type=str, choices=["dense", "llama4-maverick"])
+        parser.add_argument("--quantized", action="store_true", default=False)
+        args = parser.parse_args()
+        dtypes = quantized_dtypes if args.quantized else dense_dtypes
+        if args.name == "dense":
+            assert args.ep == 1, "EP must be 1 for dense"
+            roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dtypes, TP=args.tp, EP=args.ep, name="dense")
+        else:
+            roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dtypes, TP=args.tp, EP=args.ep, name="llama4-maverick")
+        triton_dist.cleanup()
+    else:
+        roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
+        roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
+        roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
+        roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
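All the torchrun recipes above keep tp * ep equal to the process count, while bench_mlp additionally treats every rank as a data-parallel worker (DP = world_size, batch // DP rows each). A worked check for the 4-GPU --tp 2 --ep 2 llama4-maverick case (the batch value is invented for illustration):

world_size = 4                # torchrun --nproc-per-node=4
TP, EP = 2, 2                 # --tp 2 --ep 2
DP = world_size               # bench_mlp data-parallelizes over every rank
n_expts_tot, dim2, batch = 128, 8192, 512

assert n_expts_tot % EP == 0  # 64 experts held per EP shard
assert dim2 % TP == 0         # 4096 columns of w1 per TP shard
print(batch // DP)            # 128 input rows handled by each rank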
python/triton_kernels/bench/bench_utils.py

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+from triton_kernels.numerics import InFlexData
+from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
+from triton_kernels.tensor import convert_layout
+from triton_kernels.tensor import wrap_torch_tensor, FP4
+from triton_kernels.target_info import is_cuda, get_cdna_version, cuda_capability_geq
+import torch
+
+
+def quantize_weight(w, dtype, **opt):
+    if dtype == "bf16":
+        wq = w.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2)
+        return wq, InFlexData(), None
+    elif dtype == "fp8":
+        fp8e4_dtype = torch.float8_e4m3fn if get_cdna_version() != 3 else torch.float8_e4m3fnuz
+        wq = w.to(fp8e4_dtype)
+        if is_cuda() and not cuda_capability_geq(10, 0):
+            wq = wq.transpose(-1, -2).contiguous().transpose(-1, -2)
+        return wq, InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), None
+    else:
+        assert dtype == "mx4", f"{dtype=}"
+        w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
+        if opt:
+            w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
+            w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
+        return w, InFlexData(), w_scale
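Relocated from bench_mlp.py (where it was called quantize), the helper keeps its contract: it returns the quantized weight, an InFlexData describing the flex scaling, and a scale tensor only for mx4. A minimal usage sketch based solely on the code above (the shape is invented):

import torch
from bench_utils import quantize_weight

w = torch.randn(8, 256, 512, device="cuda")  # (experts, in, out), invented shape

# bf16 path: a dtype cast plus a stride shuffle; no scale tensor
wq, flex, scale = quantize_weight(w, "bf16")
assert wq.dtype == torch.bfloat16 and scale is None

# mx4 path without layout opts: packed fp4 data plus an mx scale tensor
w4, flex4, w4_scale = quantize_weight(w, "mx4")
assert w4_scale is not None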
