Skip to content

Commit 00cf53f

Browse files
authored
[BENCH] Incorporate EP sharding and deprecate the legacy communication (#8493)
TP > 1 is not supported in this mode
1 parent 1c72fb6 commit 00cf53f

File tree

3 files changed

+91
-514
lines changed

3 files changed

+91
-514
lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
6969
x_dtype = torch.float8_e4m3fnuz
7070

7171
input_x = torch.randn((batch // DP, dim1), device=dev)
72+
expt_assignment = triton_dist.create_expt_assignment(EP, n_expts_tot, torch.device(dev))
7273
# run layer
7374
fpath = Path(tempfile.mktemp())
7475
proton.start(str(fpath), hook="triton")
@@ -78,15 +79,15 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
7879
if n_expts_tot > 1: # sparse
7980
logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
8081
x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(input_x, logits, n_expts_act, EP=EP,
81-
TP=TP)
82+
TP=TP, expt_assignment=expt_assignment)
8283
else: # dense
8384
x = triton_dist.all_gather(input_x, dim=0)
8485
rdata, gather_indx, scatter_indx, metadata = None, None, None, None
8586
if x.nelement() > 0:
8687
x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
8788
x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx,
8889
precision_config=pc2)
89-
x = triton_dist.reduce_scatter(x, metadata=metadata, dim=0)
90+
x = triton_dist.reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment)
9091
proton.finalize()
9192
return roofline.parse_profile(fpath.with_suffix(".hatchet"), useful_op_regex=".*matmul.*")
9293

@@ -136,6 +137,8 @@ def roofline_mlp(batch_sizes, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
136137
parser.add_argument("--name", type=str, choices=["dense", "gpt-oss-x2"])
137138
parser.add_argument("--quantized", action="store_true", default=False)
138139
args = parser.parse_args()
140+
if args.tp > 1:
141+
raise NotImplementedError("TP>1 is not supported yet in distributed mode.")
139142
dtypes = quantized_dtypes if args.quantized else dense_dtypes
140143
if args.name == "dense":
141144
assert args.ep == 1, "EP must be 1 for dense"

0 commit comments

Comments
 (0)